diff --git a/neutone_midi_sdk/README.md b/neutone_midi_sdk/README.md
new file mode 100644
index 0000000..611f50f
--- /dev/null
+++ b/neutone_midi_sdk/README.md
@@ -0,0 +1,109 @@
+# Neutone-MIDI SDK
+
+The goal of this SDK is to provide an environment where researchers, musicians and engineers
+can quickly 'wrap' an existing machine-learning model for symbolic music tasks into a format that can be deployed
+in a real-time plugin for DAWs.
+
+There are two guides to help with this process:
+1. model_training_guide: details the setup you should follow
+in the training pipeline to ensure your model will be compatible
+with the SDK
+2. model_preparation_guide: details how to export your model
+
+Once your model is trained and serialized with the above
+methods, you will find the remaining instructions in this Readme
+to 'wrap' it for deployment in the Neutone-MIDI plugin.
+
+We have designed the SDK to work in conjunction with [MIDITok](https://github.com/Natooz/MidiTok),
+which lets you
+tokenise an entire collection of MIDI files in a few easy commands. The SDK converts the MIDI
+data in DAWs to and from this format, allowing your model to interact with the same data format that it was trained on.
+
+
+# Wrapping your model
+
+Once you have serialized a model trained on a supported tokenization format, it's time to wrap it!
+
+**First, load your vocab and config files**
+```python
+import torch
+import json
+from neutone_midi_sdk import MidiToMidiBase, TokenData, prepare_token_data
+
+with open(vocab_file_path, 'r') as fp:
+    vocab = json.load(fp)
+with open(config_file_path, 'r') as fp:
+    config = json.load(fp)
+tokenizer_type = config["tokenization"]
+```
+
+Load your serialized model:
+```python
+remi_model = torch.jit.load("path_to_model")
+```
+
+Wrap it:
+```python
+tokenizer_data: TokenData = prepare_token_data(tokenizer_type, vocab, config)
+wrapped_model = MidiToMidiBase(model=remi_model,
+                               vocab=vocab,
+                               tokenizer_type=tokenizer_type,
+                               tokenizer_data=tokenizer_data)
+scripted_model = torch.jit.script(wrapped_model)
+scripted_model.save('REMI_Model.pt')
+```
+And...that's it! Your model is now ready to deploy in the Neutone-MIDI plugin.
+
+
+# SDK Components
+### Neutone-MIDI SDK:
+
+Provides the base wrapper for a MIDI-to-MIDI model, which is saved as a scripted PyTorch .pt file.
+
+
+### Data Preparation
+Each tokenization method has a particular set of quantized values that are available,
+related to pitch, timing, velocity, etc. Because sequence length often has a large impact
+on computation time, each model can use a slightly different granularity. To maintain efficiency,
+it is helpful for the scripted model to have lists that already identify these available values.
+
+For example, if a MIDI message comes in with ``velocity=43`` and the available values are
+``[20, 40, 60, 80, 100, 120]`` then the tokenizer can quickly round the incoming velocity to the
+nearest value of ``40``.
+
+Given the original vocab json and the type of tokenization method, the data preparation utility
+returns a TokenData object holding dictionaries of lists of the relevant values. As this can be accomplished during the
+wrapping procedure, it saves the plugin from having to calculate the available values on each forward pass.
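+
+As a rough sketch of that rounding step (illustrative only, not the SDK's internal code):
+
+```python
+available = [20, 40, 60, 80, 100, 120]  # quantized velocity values taken from the vocab
+velocity = 43                           # incoming MIDI velocity
+quantized = min(available, key=lambda v: abs(v - velocity))  # -> 40
+```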
+
+
+### MIDI Data Format
+
+The input tensor has shape (x, 4), where x is the number of MIDI messages.
+Each MIDI message has the format:
+
+``{type, value, velocity, timestep}``
+
+Current types:
+```
+0.0 = note on
+
+1.0 = note off
+```
+
+``{0.0, 64.0, 90.0, 2.5}`` = note on, pitch of 64, velocity 90, at beat 2.5
+
+Every tokenization method expects this format as input, and returns it as output.
+
+**Timing**:
+
+Within the C++ environment, timing is always expressed as **PPQ**, a float value measured in quarter notes.
+Continuing the above example, '2.5' means an eighth note (0.5) after the second quarter note (2). MIDI can communicate time in a number of formats
+and resolutions, but the input and output must always adhere to this format, as it determines where the plugin places MIDI events within the
+buffer.
+
+If, for example, your model uses a 'ticks-per-beat' system with a resolution of 96 ticks-per-quarter,
+then it is the job of the tokeniser to convert between the PPQ and ticks systems. All included tokenisation
+methods already take care of this functionality.
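+
+As an illustrative sketch (assuming the 96 ticks-per-quarter resolution mentioned above), that conversion
+amounts to a simple scaling in each direction:
+
+```python
+TICKS_PER_QUARTER = 96
+
+def ppq_to_ticks(ppq_position: float) -> int:
+    return round(ppq_position * TICKS_PER_QUARTER)  # e.g. 2.5 beats -> 240 ticks
+
+def ticks_to_ppq(ticks: int) -> float:
+    return ticks / TICKS_PER_QUARTER                # e.g. 240 ticks -> 2.5 beats
+```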
diff --git a/neutone_midi_sdk/__init__.py b/neutone_midi_sdk/__init__.py
new file mode 100644
index 0000000..c975eb7
--- /dev/null
+++ b/neutone_midi_sdk/__init__.py
@@ -0,0 +1,6 @@
+from .core import NeutoneMIDIModel
+from .tokenization import *
+from .parameter import *
+from .data_preparation import *
+from .constants import *
+from .neutoneMIDI_SDK import *
diff --git a/neutone_midi_sdk/constants.py b/neutone_midi_sdk/constants.py
new file mode 100644
index 0000000..e7df62f
--- /dev/null
+++ b/neutone_midi_sdk/constants.py
@@ -0,0 +1,6 @@
+SDK_VERSION = "0.1.1"
+
+MAX_N_NUMERICAL_PARAMS = 4
+MAX_N_TENSOR_PARAMS = 1
+SUPPORTED_TOKENIZATIONS = ["MIDILike", "TSD", "REMI", "HVO", "HVO_taps", "Custom"]
+MAX_N_CATEGORICAL_VALUES = 20
\ No newline at end of file
diff --git a/neutone_midi_sdk/core.py b/neutone_midi_sdk/core.py
new file mode 100644
index 0000000..696255e
--- /dev/null
+++ b/neutone_midi_sdk/core.py
@@ -0,0 +1,199 @@
+import torch as tr
+from torch import nn, Tensor
+from typing import List, Dict, Tuple, Union
+from abc import abstractmethod
+from neutone_midi_sdk.tokenization import TokenData
+from neutone_midi_sdk.parameter import NeutoneParameter
+import neutone_midi_sdk.constants as constants
+
+
+class NeutoneMIDIModel(tr.nn.Module):
+    def __init__(self,
+                 model: tr.nn.Module,
+                 vocab: Dict[str, int],
+                 tokenizer_type: str,
+                 tokenizer_data: TokenData):
+
+        super().__init__()
+        self.MAX_N_NUMERICAL_PARAMS = constants.MAX_N_NUMERICAL_PARAMS
+        self.MAX_N_TENSOR_PARAMS = constants.MAX_N_TENSOR_PARAMS
+        self.SDK_VERSION = constants.SDK_VERSION
+        self.n_neutone_parameters = len(self.get_neutone_parameters())
+
+        # Allocate default numerical params to prevent dynamic allocations later
+        numerical_default_param_vals = self._get_numerical_default_param_values()
+        assert len(numerical_default_param_vals) <= self.MAX_N_NUMERICAL_PARAMS, (
+            f"Number of default numerical parameter values ({len(numerical_default_param_vals)}) "
+            f"exceeds the maximum allowed ({self.MAX_N_NUMERICAL_PARAMS})."
+        )
+        numerical_default_param_values_t = tr.tensor([v for _, v in numerical_default_param_vals])
+        # Ensure number of parameters is within the maximum allowed
+        self.n_numerical_neutone_parameters = len(numerical_default_param_vals)
+        assert self.n_numerical_neutone_parameters <= self.MAX_N_NUMERICAL_PARAMS
+        # Ensure parameter names are unique
+        assert len(set([p.name for p in self.get_neutone_parameters()])) == len(
+            self.get_neutone_parameters()
+        )
+        self.register_buffer("numerical_default_param_values", numerical_default_param_values_t.unsqueeze(-1))
+
+        # Allocate default tensor params to prevent dynamic allocations later
+        tensor_default_param_vals = self._get_tensor_default_param_values()
+        assert len(tensor_default_param_vals) <= self.MAX_N_TENSOR_PARAMS, (
+            f"Number of default tensor parameter values ({len(tensor_default_param_vals)}) "
+            f"exceeds the maximum allowed ({self.MAX_N_TENSOR_PARAMS})."
+        )
+        # TODO(nic): this assumes a common dimension for all tensor parameters
+        tensor_default_param_values_t = tr.cat([v for _, v in tensor_default_param_vals])
+        self.register_buffer("tensor_default_param_values", tensor_default_param_values_t.unsqueeze(-1))
+
+        # Save parameter metadata
+        self.neutone_parameters_metadata = {
+            p.name: p.to_metadata_dict()
+            for p in self.get_neutone_parameters()
+        }
+
+        # Allocate remapped params dictionary to prevent dynamic allocations later
+        self.remapped_params = {
+            name: tr.tensor([val])
+            for name, val in numerical_default_param_vals
+        }
+        self.remapped_params.update(
+            {
+                name: val
+                for name, val in tensor_default_param_vals
+            }
+        )
+        self.default_param_values = self.remapped_params
+
+        # Save parameter information
+        self.neutone_parameter_names = [p.name for p in self.get_neutone_parameters()]
+        # TODO(nic): remove from here once plugin metadata parsing is implemented
+        self.neutone_parameter_descriptions = [
+            p.description for p in self.get_neutone_parameters()
+        ]
+        self.neutone_parameter_used = [p.used for p in self.get_neutone_parameters()]
+        self.neutone_parameter_types = [
+            p.type.value for p in self.get_neutone_parameters()
+        ]
+
+        # Instantiate model
+        model.eval()
+        self.model = model
+
+        # Set up tokenization methods
+        assert tokenizer_type in constants.SUPPORTED_TOKENIZATIONS, \
+            f"{tokenizer_type} is not a recognized tokenization format."
+        tokenizer_data = generate_fake_token_data() if tokenizer_data is None else tokenizer_data
+        vocab = {"v": 0} if vocab is None else vocab
+        self.midi_to_token_vocab = vocab
+        self.token_to_midi_vocab = {v: k for k, v in vocab.items()}
+        self.tokenizer_type = tokenizer_type
+        self.tokenizer_data: TokenData = TokenData(tokenizer_data.strings, tokenizer_data.floats, tokenizer_data.ints)
+
+    @abstractmethod
+    def _get_numerical_default_param_values(
+        self,
+    ) -> List[Tuple[str, Union[float, int]]]:
+        """
+        Returns a list of tuples containing the name and default value of each
+        numerical (float or int) parameter.
+        This should not be overwritten by SDK users.
+        """
+        pass
+
+    @abstractmethod
+    def _get_tensor_default_param_values(
+        self,
+    ) -> List[Tuple[str, tr.Tensor]]:
+        """
+        Returns a list of tuples containing the name and default value of each
+        tensor parameter.
+        This should not be overwritten by SDK users.
+        """
+        pass
+
+    @abstractmethod
+    def get_model_name(self) -> str:
+        """
+        Used to set the model name.
+        """
+        pass
+
+    @abstractmethod
+    def get_model_authors(self) -> List[str]:
+        """
+        Used to set the model authors. This will be displayed on both the
+        website and the plugin.
+
+        Should reflect the names of the people who developed the wrapper
+        of the model using the SDK. These can be different from the authors of
+        the original model.
+
+        Maximum of 5 authors.
+        """
+        pass
+
+    @abstractmethod
+    def get_model_short_description(self) -> str:
+        """
+        Used to set the model's short description. This will be displayed on both
+        the website and the plugin.
+
+        This is meant to be seen by the audio creators and should give a summary
+        of what the model does.
+
+        Maximum of 150 characters.
+        """
+        pass
+
+    def get_neutone_parameters(self) -> List[NeutoneParameter]:
+        return []
+
+    @tr.jit.export
+    def get_neutone_parameters_metadata(self) -> Dict[str, Dict[str, str]]:
+        """
+        Returns the metadata of the parameters as a string dictionary of string
+        dictionaries.
+        """
+        return self.neutone_parameters_metadata
+
+    @tr.jit.export
+    def get_default_param_values(self) -> Dict[str, Tensor]:
+        """
+        Returns the default parameter values as a dictionary mapping each
+        parameter name to its default value tensor.
+        """
+        return self.default_param_values
+
+    @tr.jit.export
+    def get_default_param_names(self) -> List[str]:
+        # TODO(nic): remove this once plugin metadata parsing is implemented
+        return self.neutone_parameter_names
+
+    @tr.jit.export
+    def get_default_param_descriptions(self) -> List[str]:
+        # TODO(nic): remove this once plugin metadata parsing is implemented
+        return self.neutone_parameter_descriptions
+
+    @tr.jit.export
+    def get_default_param_types(self) -> List[str]:
+        # TODO(nic): remove this once plugin metadata parsing is implemented
+        return self.neutone_parameter_types
+
+    @tr.jit.export
+    def get_default_param_used(self) -> List[bool]:
+        # TODO(nic): remove this once plugin metadata parsing is implemented
+        return self.neutone_parameter_used
+
+    def prepare_for_inference(self) -> None:
+        self.model.eval()
+        self.eval()
+
+
+# Todo: Would like to deprecate this method; it is used for the "HVO" format, where no TokenData is necessary
+def generate_fake_token_data():
+    token_strings: Dict[str, List[str]] = {"value": ["value"]}
+    token_floats: Dict[str, List[float]] = {"value": [0.0]}
+    token_ints: Dict[str, List[int]] = {"value": [0]}
+    token_data: TokenData = TokenData(token_strings, token_floats, token_ints)
+    return token_data
diff --git a/neutone_midi_sdk/data_preparation.py b/neutone_midi_sdk/data_preparation.py
new file mode 100644
index 0000000..1c4b227
--- /dev/null
+++ b/neutone_midi_sdk/data_preparation.py
@@ -0,0 +1,183 @@
+from typing import Dict, List, Tuple
+from neutone_midi_sdk.tokenization import TokenData
+
+
+def prepare_token_data(token_type: str,
+                       midi_to_token_vocab: Dict[str, int],
+                       config) \
+        -> TokenData:
+    """
+    While the forward method is tokenizing data, it requires knowledge of the 'available'
+    values. For example, if velocity is constrained to 32 values (instead of 127), it needs
+    a list of these available values. To save on computation time within the plugin environment,
+    this function is run once during model wrapping to extract all necessary values. They are
+    returned in a TokenData object of dictionaries of lists, each of which can be accessed quickly in real time.
+
+    Each tokenization has its own set of unique token types, e.g.
+    MIDILike: [TimeShift, Pitch, Velocity]
+    """
+
+    assert token_type in ["MIDILike", "TSD", "REMI", "Custom"], "Incorrect tokenization method specified."
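+
+    # For example (illustrative): a vocab containing "Velocity_3" ... "Velocity_127" yields
+    # token_ints["velocity_values"] = [3, 7, ..., 127], so an incoming velocity can be
+    # quantized with a fast nearest-value lookup instead of re-parsing the vocab.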
+
+    if token_type == "MIDILike":
+        token_data: TokenData = prepare_MIDILike_data(midi_to_token_vocab)
+        return token_data
+
+    elif token_type == "TSD":
+        token_data: TokenData = prepare_TSD_data(midi_to_token_vocab)
+        return token_data
+
+    elif token_type == "REMI":
+        token_data: TokenData = prepare_REMI_data(midi_to_token_vocab, config)
+        return token_data
+
+    else:
+        print("No dedicated preparation for this tokenization type; falling back to MIDILike data preparation.")
+        token_data: TokenData = prepare_MIDILike_data(midi_to_token_vocab)
+        return token_data
+
+
+##### Utility Functions #####
+def extract_timing_tokens(tokens: List[str], identifier: str) -> Tuple[List[str], List[float]]:
+    """
+    Extract information from tokens related to shifts in time, i.e. TimeShift and Duration.
+    They must follow the format of "Type_beat.subdivision.granularity", e.g.
+    "TimeShift_1.2.4" = a timeshift of 1 beat plus 2 subdivisions at 4-per-beat (16ths).
+    The tokenizer needs both the string version (for tokenizing) and
+    values that convert this to float timing in MIDI format, e.g. 1.2.4 = 1.5 (1.5 quarter notes).
+    Returns a tuple containing lists of both string and float formats.
+    """
+    timing_token_strings: List[str] = []
+    for token in tokens:
+        if identifier in token:
+            timing_token_strings.append(token)
+
+    reduced_tokens: List[str] = [token.split("_")[1] for token in timing_token_strings]
+    timing_token_floats: List[float] = []
+    for value in reduced_tokens:
+        number = int(value.split(".")[0]) + int(value.split(".")[1]) * (1 / int(value.split(".")[2]))
+        timing_token_floats.append(number)
+
+    output: Tuple[List[str], List[float]] = (timing_token_strings, timing_token_floats)
+    return output
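+
+
+# Worked example (illustrative): extract_timing_tokens(["TimeShift_1.2.4", "NoteOn_60"], "TimeShift")
+# returns (["TimeShift_1.2.4"], [1.5]), since 1 + 2 * (1 / 4) = 1.5 quarter notes.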
+ """ + values: List[int] = [] + for token in tokens: + if identifier in token: + values.append(int(token.split("_")[1])) + + return values + + +## Tokenization Preparation Functions ## +def prepare_MIDILike_data(midi_to_token_vocab: Dict[str, int]) \ + -> TokenData: + + tokens = [key for key in midi_to_token_vocab.keys()] + + # Get lists of available values + timeshift_tokens: Tuple[List[str], List[float]] = extract_timing_tokens(tokens, identifier="TimeShift") + pitch_values: List[int] = extract_value_tokens(tokens, identifier="NoteOn") + velocity_values: List[int] = extract_value_tokens(tokens, identifier="Velocity") + + token_strings: Dict[str, List[str]] = {"timeshift_strings": timeshift_tokens[0]} + token_floats: Dict[str, List[float]] = {"timeshift_floats": timeshift_tokens[1]} + token_ints: Dict[str, List[int]] = {"pitch_values": pitch_values, "velocity_values": velocity_values} + + token_data: TokenData = TokenData(token_strings, token_floats, token_ints) + + return token_data + +def prepare_TSD_data(midi_to_token_vocab: Dict[str, int]) \ + -> TokenData: + + tokens = [key for key in midi_to_token_vocab.keys()] + + # Get lists of available values + timeshift_tokens: Tuple[List[str], List[float]] = extract_timing_tokens(tokens, identifier="TimeShift") + duration_tokens: Tuple[List[str], List[float]] = extract_timing_tokens(tokens, identifier="Duration") + pitch_values: List[int] = extract_value_tokens(tokens, identifier="Pitch") + velocity_values: List[int] = extract_value_tokens(tokens, identifier="Velocity") + + token_strings: Dict[str, List[str]] = {"timeshift_strings": timeshift_tokens[0], + "duration_strings": duration_tokens[0]} + + token_floats: Dict[str, List[float]] = {"timeshift_floats": timeshift_tokens[1], + "duration_floats": duration_tokens[1]} + + token_ints: Dict[str, List[int]] = {"pitch_values": pitch_values, + "velocity_values": velocity_values} + + token_data: TokenData = TokenData(token_strings, token_floats, token_ints) + + return token_data + + +def prepare_REMI_data(midi_to_token_vocab: Dict[str, int], config) \ + -> TokenData: + + tokens = [key for key in midi_to_token_vocab.keys()] + + # Determine the granularity of the "Position" tokens, defined by the (0_N: res) beat res in Miditok config file + for k, v in config["beat_res"].items(): + if "0" in (k): + pos_granularity: List[float] = [float(1/v)] + break + + position_values: List[int] = extract_value_tokens(tokens, identifier="Position") + position_floats: List[float] = list() + position_strings: List[str] = list() + for token in tokens: + if "Position" in token: + position_strings.append(token) + value = int(token.split("_")[1]) + position_floats.append(float(value * pos_granularity[0])) + + pitch_values: List[int] = extract_value_tokens(tokens, identifier="Pitch") + velocity_values: List[int] = extract_value_tokens(tokens, identifier="Velocity") + duration_tokens: Tuple[List[str], List[float]] = extract_timing_tokens(tokens, identifier="Duration") + + token_strings: Dict[str, List[str]] = {"position_strings": position_strings, + "duration_strings": duration_tokens[0]} + + token_floats: Dict[str, List[float]] = {"pos_granularity": pos_granularity, + "position_floats": position_floats, + "duration_floats": duration_tokens[1]} + + token_ints: Dict[str, List[int]] = {"pitch_values": pitch_values, + "velocity_values": velocity_values, + "pos_values": position_values} + + token_data: TokenData = TokenData(token_strings, token_floats, token_ints) + + return token_data \ No newline at end of file diff --git 
diff --git a/neutone_midi_sdk/examples/remi/config.json b/neutone_midi_sdk/examples/remi/config.json
new file mode 100644
index 0000000..2fc19b9
--- /dev/null
+++ b/neutone_midi_sdk/examples/remi/config.json
@@ -0,0 +1,37 @@
+{
+  "pitch_range": [
+    21,
+    109
+  ],
+  "beat_res": {
+    "0_16": 4,
+    "16_24": 2
+  },
+  "_nb_velocities": 32,
+  "additional_tokens": {
+    "Chord": false,
+    "Rest": false,
+    "Tempo": false,
+    "Program": false,
+    "TimeSignature": false,
+    "rest_range": [
+      2,
+      8
+    ],
+    "nb_tempos": 32,
+    "tempo_range": [
+      40,
+      250
+    ]
+  },
+  "special_tokens": [
+    "PAD",
+    "BOS",
+    "EOS",
+    "MASK"
+  ],
+  "unique_track": false,
+  "has_bpe": false,
+  "tokenization": "REMI",
+  "miditok_version": "2.0.3"
+}
\ No newline at end of file
diff --git a/neutone_midi_sdk/examples/remi/remi_wrapper.py b/neutone_midi_sdk/examples/remi/remi_wrapper.py
new file mode 100644
index 0000000..fbbddc2
--- /dev/null
+++ b/neutone_midi_sdk/examples/remi/remi_wrapper.py
@@ -0,0 +1,64 @@
+import json
+import torch
+from typing import Dict, List
+from neutone_midi_sdk import (ContinuousNeutoneParameter, MidiToMidiBase, NeutoneParameter,
+                              TokenData, prepare_token_data)
+
+
+class RemiModel(torch.nn.Module):
+    def forward(self, x: torch.Tensor):
+        return x
+
+
+class RemiModelWrapper(MidiToMidiBase):
+    """
+    Here you can overwrite several methods to define your model's functionality.
+    The most important is "do_forward_pass"; this is where you can define custom behavior, such
+    as inserting parameters or modifying the sampling logic.
+    """
+    def get_model_name(self) -> str:
+        return "neutone_remi"
+
+    def get_model_authors(self) -> List[str]:
+        return ["Julian Lenz"]
+
+    def get_model_short_description(self) -> str:
+        return "REMI melody generation"
+
+    def get_neutone_parameters(self) -> List[NeutoneParameter]:
+        return [
+            ContinuousNeutoneParameter("temperature", "sampling temp", default_value=0.6)
+        ]
+
+    def do_forward_pass(self, tokenized_data: torch.Tensor, params: Dict[str, torch.Tensor]) -> torch.Tensor:
+        # In reality this model doesn't use params, but here is how you can retrieve one
+        temperature = params["temperature"]
+        output = self.model.forward(tokenized_data)
+        return output
+
+
+if __name__ == "__main__":
+
+    # Load config and vocab files
+    with open("config.json", "r") as fp:
+        config = json.load(fp)
+    with open("vocab.json", "r") as fp:
+        vocab = json.load(fp)
+
+    # Get pre-processed data
+    tokenizer_type = "REMI"
+    tokenizer_data: TokenData = prepare_token_data(tokenizer_type, vocab, config)
+
+    # Load model
+    # Normally you would load a trained model; for this demo, we use the dummy model instead
+    # model = torch.jit.load("path_to_trained_model.pt")
+    model = RemiModel()
+
+    # Wrap it with the SDK and export
+    wrapped_model = RemiModelWrapper(model=model,
+                                     vocab=vocab,
+                                     tokenizer_type=tokenizer_type,
+                                     tokenizer_data=tokenizer_data)
+    scripted_model = torch.jit.script(wrapped_model)
+    scripted_model.save("neutone_remi_model.pt")
diff --git a/neutone_midi_sdk/examples/remi/vocab.json b/neutone_midi_sdk/examples/remi/vocab.json
new file mode 100644
index 0000000..3e96c77
--- /dev/null
+++ b/neutone_midi_sdk/examples/remi/vocab.json
@@ -0,0 +1 @@
+{"PAD_None": 0, "BOS_None": 1, "EOS_None": 2, "MASK_None": 3, "Bar_None": 4, "Pitch_21": 5, "Pitch_22": 6, "Pitch_23": 7, "Pitch_24": 8, "Pitch_25": 9, "Pitch_26": 10, "Pitch_27": 11, "Pitch_28": 12, "Pitch_29": 13, "Pitch_30": 14, "Pitch_31": 15, "Pitch_32": 16, "Pitch_33": 17, "Pitch_34": 18, "Pitch_35": 19, "Pitch_36": 20, "Pitch_37": 21, "Pitch_38": 22,
"Pitch_39": 23, "Pitch_40": 24, "Pitch_41": 25, "Pitch_42": 26, "Pitch_43": 27, "Pitch_44": 28, "Pitch_45": 29, "Pitch_46": 30, "Pitch_47": 31, "Pitch_48": 32, "Pitch_49": 33, "Pitch_50": 34, "Pitch_51": 35, "Pitch_52": 36, "Pitch_53": 37, "Pitch_54": 38, "Pitch_55": 39, "Pitch_56": 40, "Pitch_57": 41, "Pitch_58": 42, "Pitch_59": 43, "Pitch_60": 44, "Pitch_61": 45, "Pitch_62": 46, "Pitch_63": 47, "Pitch_64": 48, "Pitch_65": 49, "Pitch_66": 50, "Pitch_67": 51, "Pitch_68": 52, "Pitch_69": 53, "Pitch_70": 54, "Pitch_71": 55, "Pitch_72": 56, "Pitch_73": 57, "Pitch_74": 58, "Pitch_75": 59, "Pitch_76": 60, "Pitch_77": 61, "Pitch_78": 62, "Pitch_79": 63, "Pitch_80": 64, "Pitch_81": 65, "Pitch_82": 66, "Pitch_83": 67, "Pitch_84": 68, "Pitch_85": 69, "Pitch_86": 70, "Pitch_87": 71, "Pitch_88": 72, "Pitch_89": 73, "Pitch_90": 74, "Pitch_91": 75, "Pitch_92": 76, "Pitch_93": 77, "Pitch_94": 78, "Pitch_95": 79, "Pitch_96": 80, "Pitch_97": 81, "Pitch_98": 82, "Pitch_99": 83, "Pitch_100": 84, "Pitch_101": 85, "Pitch_102": 86, "Pitch_103": 87, "Pitch_104": 88, "Pitch_105": 89, "Pitch_106": 90, "Pitch_107": 91, "Pitch_108": 92, "Velocity_3": 93, "Velocity_7": 94, "Velocity_11": 95, "Velocity_15": 96, "Velocity_19": 97, "Velocity_23": 98, "Velocity_27": 99, "Velocity_31": 100, "Velocity_35": 101, "Velocity_39": 102, "Velocity_43": 103, "Velocity_47": 104, "Velocity_51": 105, "Velocity_55": 106, "Velocity_59": 107, "Velocity_63": 108, "Velocity_67": 109, "Velocity_71": 110, "Velocity_75": 111, "Velocity_79": 112, "Velocity_83": 113, "Velocity_87": 114, "Velocity_91": 115, "Velocity_95": 116, "Velocity_99": 117, "Velocity_103": 118, "Velocity_107": 119, "Velocity_111": 120, "Velocity_115": 121, "Velocity_119": 122, "Velocity_123": 123, "Velocity_127": 124, "Duration_0.1.4": 125, "Duration_0.2.4": 126, "Duration_0.3.4": 127, "Duration_1.0.4": 128, "Duration_1.1.4": 129, "Duration_1.2.4": 130, "Duration_1.3.4": 131, "Duration_2.0.4": 132, "Duration_2.1.4": 133, "Duration_2.2.4": 134, "Duration_2.3.4": 135, "Duration_3.0.4": 136, "Duration_3.1.4": 137, "Duration_3.2.4": 138, "Duration_3.3.4": 139, "Duration_4.0.4": 140, "Duration_4.1.4": 141, "Duration_4.2.4": 142, "Duration_4.3.4": 143, "Duration_5.0.4": 144, "Duration_5.1.4": 145, "Duration_5.2.4": 146, "Duration_5.3.4": 147, "Duration_6.0.4": 148, "Duration_6.1.4": 149, "Duration_6.2.4": 150, "Duration_6.3.4": 151, "Duration_7.0.4": 152, "Duration_7.1.4": 153, "Duration_7.2.4": 154, "Duration_7.3.4": 155, "Duration_8.0.4": 156, "Duration_8.1.4": 157, "Duration_8.2.4": 158, "Duration_8.3.4": 159, "Duration_9.0.4": 160, "Duration_9.1.4": 161, "Duration_9.2.4": 162, "Duration_9.3.4": 163, "Duration_10.0.4": 164, "Duration_10.1.4": 165, "Duration_10.2.4": 166, "Duration_10.3.4": 167, "Duration_11.0.4": 168, "Duration_11.1.4": 169, "Duration_11.2.4": 170, "Duration_11.3.4": 171, "Duration_12.0.4": 172, "Duration_12.1.4": 173, "Duration_12.2.4": 174, "Duration_12.3.4": 175, "Duration_13.0.4": 176, "Duration_13.1.4": 177, "Duration_13.2.4": 178, "Duration_13.3.4": 179, "Duration_14.0.4": 180, "Duration_14.1.4": 181, "Duration_14.2.4": 182, "Duration_14.3.4": 183, "Duration_15.0.4": 184, "Duration_15.1.4": 185, "Duration_15.2.4": 186, "Duration_15.3.4": 187, "Duration_16.0.2": 188, "Duration_16.1.2": 189, "Duration_17.0.2": 190, "Duration_17.1.2": 191, "Duration_18.0.2": 192, "Duration_18.1.2": 193, "Duration_19.0.2": 194, "Duration_19.1.2": 195, "Duration_20.0.2": 196, "Duration_20.1.2": 197, "Duration_21.0.2": 198, "Duration_21.1.2": 199, "Duration_22.0.2": 200, 
"Duration_22.1.2": 201, "Duration_23.0.2": 202, "Duration_23.1.2": 203, "Duration_24.0.2": 204, "Position_0": 205, "Position_1": 206, "Position_2": 207, "Position_3": 208, "Position_4": 209, "Position_5": 210, "Position_6": 211, "Position_7": 212, "Position_8": 213, "Position_9": 214, "Position_10": 215, "Position_11": 216, "Position_12": 217, "Position_13": 218, "Position_14": 219, "Position_15": 220} \ No newline at end of file diff --git a/neutone_midi_sdk/neutoneMIDI_SDK.py b/neutone_midi_sdk/neutoneMIDI_SDK.py new file mode 100644 index 0000000..fe84b4f --- /dev/null +++ b/neutone_midi_sdk/neutoneMIDI_SDK.py @@ -0,0 +1,146 @@ +from abc import abstractmethod +from typing import Dict, List, Optional, Tuple, Union + +import torch as tr + +from neutone_midi_sdk import (ContinuousNeutoneParameter, NeutoneMIDIModel, + NeutoneParameterType) +from neutone_midi_sdk.tokenization import (TokenData, convert_midi_to_tokens, + convert_tokens_to_midi) + + +class MidiToMidiBase(NeutoneMIDIModel): + def __init__(self, + model: tr.nn.Module, + vocab: Dict[str, int], + tokenizer_type: str, + tokenizer_data: TokenData, + add_dimension: bool = True): + super().__init__(model, vocab, tokenizer_type, tokenizer_data) + self.add_dimension = add_dimension + + assert all( + p.type == NeutoneParameterType.CONTINUOUS or p.type == NeutoneParameterType.TENSOR + for p in self.get_neutone_parameters() + ), ( + "Only continuous or tensor type parameters are supported in MidiToMidiBase models. " + ) + + # For compatibility with the current plugin, we fill in missing params + # TODO(nic): remove once plugin metadata parsing is implemented + for idx in range(self.n_neutone_parameters, self.MAX_N_NUMERICAL_PARAMS): + unused_p = ContinuousNeutoneParameter( + name="", + description="", + default_value=0.0, + used=False, + ) + self.neutone_parameters_metadata[f"p{idx+1}"] = unused_p.to_metadata_dict() + self.neutone_parameter_names.append(unused_p.name) + self.neutone_parameter_descriptions.append(unused_p.description) + self.neutone_parameter_types.append(unused_p.type.value) + self.neutone_parameter_used.append(unused_p.used) + + def prepare_for_inference(self) -> None: + super().prepare_for_inference() + + @abstractmethod + def do_forward_pass(self, tokenized_data: tr.Tensor, params: Dict[str, tr.Tensor]) -> tr.Tensor: + """ + SDK users can overwrite this method to implement the logic of their models. + The input is a tensor of data that has been tokenized according to the tokenization settings, + i.e. REMI, TSD, etc. + + In addition to the forward pass of your model, you can incorporate additional logic, such as + the control parameters. + + The model should return data in the same format it was input, i.e. REMI-in, REMI-out. This will then + be de-tokenized in the top-level 'forward' method. + """ + pass + + def forward(self, midi_data: tr.Tensor, params: Optional[Dict[str, tr.Tensor]] = None) -> tr.Tensor: + + if params is None: + # This codepath should never be reached, as the plugin always sends parameters. 
+            params = self.get_default_param_values()
+
+        for n in self.neutone_parameter_names:
+            if n not in params:
+                raise ValueError(f"Parameter {n} not found in input parameters.")
+            self.remapped_params[n] = params[n]
+
+        for p in self.neutone_parameters_metadata.keys():
+            if self.neutone_parameters_metadata[p]["type"] == NeutoneParameterType.TENSOR.value and \
+                    self.neutone_parameters_metadata[p]["tokenize"] == str(True):
+                name = self.neutone_parameters_metadata[p]["name"]
+                # TODO: change this to token_type=self.tokenizer_type once deprecating HVO_taps
+                self.remapped_params[name] = convert_midi_to_tokens(midi_data=params[name],
+                                                                    token_type="HVO",
+                                                                    midi_to_token_vocab=self.midi_to_token_vocab,
+                                                                    tokenizer_data=self.tokenizer_data)
+
+        tokenized_data = convert_midi_to_tokens(midi_data=midi_data,
+                                                token_type=self.tokenizer_type,
+                                                midi_to_token_vocab=self.midi_to_token_vocab,
+                                                tokenizer_data=self.tokenizer_data)
+
+        if self.add_dimension:
+            tokenized_data = tr.unsqueeze(tokenized_data, dim=0)
+        model_output = self.do_forward_pass(tokenized_data, self.remapped_params)
+        if self.add_dimension:
+            model_output = tr.squeeze(model_output, dim=0)
+
+        output_midi_data = convert_tokens_to_midi(tokens=model_output,
+                                                  token_type=self.tokenizer_type,
+                                                  token_to_midi_vocab=self.token_to_midi_vocab,
+                                                  tokenizer_data=self.tokenizer_data)
+
+        return output_midi_data
+
+    def _get_numerical_default_param_values(
+        self,
+    ) -> List[Tuple[str, Union[float, int]]]:
+        """
+        Returns a list of tuples containing the name and default value of each
+        numerical (float or int) parameter.
+        For MidiToMidi models, there are always self.MAX_N_NUMERICAL_PARAMS
+        numerical default parameter values, no matter how many parameters have been
+        defined. This is to prevent empty tensors in some of the internal piping
+        and queues when the model has no parameters.
+        This should not be overwritten by SDK users.
+        """
+        result = []
+        for p in self.get_neutone_parameters():
+            if p.type == NeutoneParameterType.CONTINUOUS:
+                result.append((p.name, p.default_value))
+        if len(result) < self.MAX_N_NUMERICAL_PARAMS:
+            result.extend(
+                [
+                    (f"p{idx + 1}", 0.0)
+                    for idx in range(len(result), self.MAX_N_NUMERICAL_PARAMS)
+                ]
+            )
+        return result
+
+    def _get_tensor_default_param_values(
+        self,
+    ) -> List[Tuple[str, tr.Tensor]]:
+        """
+        Returns a list of tuples containing the name and default value of each
+        tensor parameter.
+        This should not be overwritten by SDK users.
+ """ + result = [] + for p in self.get_neutone_parameters(): + if p.type == NeutoneParameterType.TENSOR: + result.append((p.name, p.default_value)) + return result + +def generate_fake_token_data(): + token_strings: Dict[str, List[str]] = {"value": ["value"]} + token_floats: Dict[str, List[float]] = {"value": [0.0]} + token_ints: Dict[str, List[int]] = {"value": [0]} + token_data: TokenData = TokenData(token_strings, token_floats, token_ints) + return token_data diff --git a/neutone_midi_sdk/parameter.py b/neutone_midi_sdk/parameter.py new file mode 100644 index 0000000..033d664 --- /dev/null +++ b/neutone_midi_sdk/parameter.py @@ -0,0 +1,96 @@ +import logging +import os +from abc import ABC +from enum import Enum +import torch as tr +from typing import Dict, Union, Tuple + +from neutone_midi_sdk import constants + +logging.basicConfig() +log = logging.getLogger(__name__) +log.setLevel(level=os.environ.get("LOGLEVEL", "INFO")) + +class NeutoneParameterType(Enum): + BASE = "base" + CONTINUOUS = "continuous" + CATEGORICAL = "categorical" + TENSOR = "tensor" + +class NeutoneParameter(ABC): + """ + Defines a Neutone Parameter abstract base class. + + The name and the description of the parameter will be shown as a tooltip + within the UI. This parameter has no functionality and is meant to subclassed. + """ + + def __init__( + self, + name: str, + description: str, + default_value: Union[int, float, str, tr.Tensor], #TODO(nic): optional default_value for tensor case, or default to uniformly populating tensor with default_value + used: bool, + param_type: NeutoneParameterType, + ): + self.name = name + self.description = description + self.default_value = default_value + self.used = used + self.type = param_type + + def to_metadata_dict(self) -> Dict[str, str]: + """Returns a string dictionary containing the metadata of the parameter.""" + return { + "name": self.name, + "description": self.description, + "default_value": str(self.default_value), + "used": str(self.used), + "type": str(self.type.value), + } + +class ContinuousNeutoneParameter(NeutoneParameter): + """ + Defines a continuous Neutone Parameter that the user can use to control a model. + + The name and the description of the parameter will be shown as a tooltip + within the UI. + `default_value` must be between 0 and 1 and will be used as a default in the plugin + when no presets are available. + """ + + def __init__( + self, name: str, description: str, default_value: float, used: bool = True + ): + super().__init__( + name, + description, + default_value, + used, + NeutoneParameterType.CONTINUOUS, + ) + assert ( + 0.0 <= default_value <= 1.0 + ), "`default_value` for continuous params must be between 0 and 1" + +class TensorNeutoneParameter(NeutoneParameter): + """ + Defines a tensor Neutone Parameter that the user can use to control a model. 
+ """ + def __init__(self, name: str, description: str, shape: Tuple[int], default_value: tr.Tensor, used: bool = True): + super().__init__( + name, + description, + default_value, + used, + NeutoneParameterType.TENSOR, + ) + self.shape = shape + + def to_metadata_dict(self) -> Dict[str, str]: + """Returns a string dictionary containing the metadata of the parameter.""" + data = super().to_metadata_dict() + data["shape"] = str(self.shape) + data["default_value"] = str(self.default_value.numpy()) + data["tokenize"] = "True" + return data diff --git a/neutone_midi_sdk/tokenization.py b/neutone_midi_sdk/tokenization.py new file mode 100644 index 0000000..5580484 --- /dev/null +++ b/neutone_midi_sdk/tokenization.py @@ -0,0 +1,509 @@ +import torch +from typing import Dict, List, Tuple + + +class TokenData: + def __init__(self, + strings: Dict[str, List[str]], + floats: Dict[str, List[float]], + ints: Dict[str, List[int]]): + self.strings = strings + self.floats = floats + self.ints = ints + + def get_elements(self) -> Tuple[Dict[str, List[str]], Dict[str, List[float]], Dict[str, List[int]]]: + return self.strings, self.floats, self.ints + + +def convert_midi_to_tokens(midi_data: torch.Tensor, + token_type: str, + midi_to_token_vocab: Dict[str, int], + tokenizer_data: TokenData) \ + -> torch.Tensor: + if token_type == "MIDILike": + return convert_midi_to_midilike_tokens(midi_data, midi_to_token_vocab, tokenizer_data) + + elif token_type == "TSD": + return convert_midi_to_tsd_tokens(midi_data, midi_to_token_vocab, tokenizer_data) + + elif token_type == "REMI": + return convert_midi_to_remi_tokens(midi_data, midi_to_token_vocab, tokenizer_data) + + elif token_type == "HVO": + return convert_midi_to_hvo(midi_data) + + elif token_type == "HVO_taps": + return convert_midi_to_monophonic_hvo(midi_data) + + else: + # Todo: Needs tensor return type; how to assert this? + return torch.zeros((2, 2)) + + +def convert_tokens_to_midi(tokens: torch.Tensor, + token_type: str, + token_to_midi_vocab: Dict[int, str], + tokenizer_data: TokenData) -> torch.Tensor: + if token_type == "MIDILike": + return convert_midilike_tokens_to_midi(tokens, token_to_midi_vocab) + + if token_type == "TSD": + return convert_tsd_tokens_to_midi(tokens, token_to_midi_vocab) + + if token_type == "REMI": + position_granularity: float = tokenizer_data.floats["pos_granularity"][0] + return convert_remi_tokens_to_midi(tokens, token_to_midi_vocab, position_granularity) + + if token_type == "HVO": + return convert_hvo_to_midi(tokens) + + if token_type == "HVO_taps": + return convert_hvo_to_midi(tokens) + + else: + return torch.zeros((2, 2)) + + +""" +Utility Functions +Because torchscript scrictly enforces Typed python, it can often lead to very verbose code. +These functions perform common operations within the tokenisation methods while keeping things +fairly readable. 
+""" + + +def closest_int(input_list: List[int], value: int) -> int: + # https://stackoverflow.com/questions/12141150/from-list-of-integers-get-number-closest-to-a-given-value + aux: List[int] = [] + for n in input_list: + aux.append(abs(value - n)) + + return input_list[aux.index(min(aux))] + + +def closest_float(input_list: List[float], value: float) -> float: + aux: List[float] = [] + for n in input_list: + aux.append(abs(value - n)) + + return input_list[aux.index(min(aux))] + + +def closest_float_idx(input_list: List[float], value: float) -> int: + aux: List[float] = [] + for n in input_list: + aux.append(abs(value - n)) + + return aux.index(min(aux)) + + +def find_next_note_off_location(midi_data: torch.Tensor, note_value: int) -> float: + location = 0.0 + for message in midi_data: + if message[0] == 1.0 and int(message[1].item()) == note_value: + location = float(message[3].item()) + return location + return location + + +def extract_matching_strings(input_list: List[str], strings: List[str]) -> List[str]: + output_list: List[str] = [] + for string in strings: + for value in input_list: + if string in value: + output_list.append(value) + return output_list + + +def calculate_delta(message: str) -> float: + delta_string = message.split("_")[1] + delta = int(delta_string.split(".")[0]) + int(delta_string.split(".")[1]) * ( + 1 / int(delta_string.split(".")[2])) + return delta + + +def convert_miditok_timing_to_float(message: str) -> float: + timing_string = message.split("_")[1] + float_value = float(timing_string.split(".")[0]) + float(timing_string.split(".")[1]) * ( + 1 / float(timing_string.split(".")[2])) + return float_value + + +""" +Individual tokenization methods are here. For each category of tokenization (i.e. CPWord, REMI, MIDILike, etc.) +there needs to be two functions: midi-to-token, and token-to-midi. +We use an intermediary tuple representation of MIDI data: + +(type(int), pitch(int), velocity(int), timestamp(float)) +A note-on message with pitch of 64 and velocity of 90 on the 3rd beat of measure 5 would be: +(0, 64, 90, 5.75) + +Therefor the tokenization method should convert this tuple to a string found within the vocab, and then +to the appropriate integer. Within the function, it should take care of any pitch/velocity/time quantization that is +needed to fit within the given vocabulary. It will also likely be necessary to convert from the timeformat above to +that of the tokenization method. +""" + + +# ------- +# MIDILike +def convert_midi_to_midilike_tokens(midi_data: torch.Tensor, + vocab: Dict[str, int], + tokenizer_data: TokenData) -> torch.Tensor: + """ + Given neunote MIDI data, a midi_to-token vocab, and available data*, convert the MIDI data + into a tensor of tokens. 
+ *Available data (dict): [timeshift_strings, timeshift_floats, pitch_values, velocity_values] + """ + + token_strings: List[str] = [] + global_timestep: float = 0.0 + quantized_delta: float = 0.0 + + for idx, message in enumerate(midi_data): + + pitch = int(message[1].item()) + pitch = closest_int(tokenizer_data.ints["pitch_values"], pitch) + time = float(message[3].item()) + + # Create timeshift token + if time > global_timestep: + delta = time - global_timestep + if delta > (min(tokenizer_data.floats["timeshift_floats"]) * 0.5): + closest_idx = closest_float_idx(tokenizer_data.floats["timeshift_floats"], delta) + quantized_delta = tokenizer_data.floats["timeshift_floats"][closest_idx] + token_strings.append(tokenizer_data.strings["timeshift_strings"][closest_idx]) + global_timestep += quantized_delta + + # Note on + if message[0] == 0: + velocity = int(message[2].item()) + velocity = closest_int(tokenizer_data.ints["velocity_values"], velocity) + token_strings.append(f"NoteOn_{pitch}") + token_strings.append(f"Velocity_{velocity}") + + # Note off + elif message[0] == 1: + token_strings.append(f"NoteOff_{pitch}") + + # Finally, convert the strings into ints per vocab dictionary + tokens = [vocab[token] for token in token_strings] + tokens = torch.tensor(tokens) + + return tokens + + +def convert_midilike_tokens_to_midi(tokens: torch.Tensor, tokens_to_midi_vocab: Dict[int, str]) -> torch.FloatTensor: + midi_tuples_tensor_list: List[torch.FloatTensor] = [] + global_timestep = 0.0 + velocity = 90 + pitch = 64 + + # Convert int tensor to intermediary string token representation + string_tokens: List[str] = [tokens_to_midi_vocab[token.item()] for token in tokens] + + # Convert these messages into our final tuple format of (message_type, pitch, velocity, timestep) + for idx, message in enumerate(string_tokens): + + if "NoteOn" in message and idx != len(string_tokens) - 1: + if "Velocity" in string_tokens[idx + 1]: + velocity = int(string_tokens[idx + 1].split("_")[1]) + pitch = int(message.split("_")[1]) + midi_tuples_tensor_list.append(torch.FloatTensor([[0.0, + float(pitch), + float(velocity), + float(global_timestep)]])) + + elif "NoteOff" in message: + pitch = int(message.split("_")[1]) + midi_tuples_tensor_list.append(torch.FloatTensor([[1.0, + float(pitch), + 90.0, + float(global_timestep)]])) + + elif "TimeShift" in message: + delta = convert_miditok_timing_to_float(message) + global_timestep += delta + + else: + pass + + midi_output_tensor: torch.FloatTensor = torch.cat(midi_tuples_tensor_list, 0) + + return midi_output_tensor + + +# ------- +# TSD + +def convert_midi_to_tsd_tokens(midi_data: torch.Tensor, + vocab: Dict[str, int], + tokenizer_data: TokenData) -> torch.Tensor: + token_strings: List[str] = [] + global_timestep: float = 0.0 + + for idx, message in enumerate(midi_data): + + pitch = int(message[1].item()) + pitch_quantized = closest_int(tokenizer_data.ints["pitch_values"], pitch) + time = float(message[3].item()) + + # Note on + if message[0] == 0: + + # If later than current position, add timeshift token + if time > global_timestep: + delta = time - global_timestep + if delta > (min(tokenizer_data.floats["timeshift_floats"]) * 0.5): + closest_idx = closest_float_idx(tokenizer_data.floats["timeshift_floats"], delta) + quantized_delta = tokenizer_data.floats["timeshift_floats"][closest_idx] + token_strings.append(tokenizer_data.strings["timeshift_strings"][closest_idx]) + global_timestep += quantized_delta + + # Add pitch and velocity tokens + velocity = 
int(message[2].item()) + velocity = closest_int(tokenizer_data.ints["velocity_values"], velocity) + token_strings.append(f"Pitch_{pitch_quantized}") + token_strings.append(f"Velocity_{velocity}") + + # Duration + note_off_location: float = find_next_note_off_location(midi_data[idx:], note_value=pitch) + delta = note_off_location - global_timestep + closest_idx = closest_float_idx(tokenizer_data.floats["duration_floats"], delta) + token_strings.append(tokenizer_data.strings["duration_strings"][closest_idx]) + + tokens = [vocab[token] for token in token_strings] + tokens = torch.tensor(tokens) + return tokens + + +def convert_tsd_tokens_to_midi(tokens: torch.Tensor, + tokens_to_midi_vocab: Dict[int, str]) -> torch.FloatTensor: + midi_tuples_tensor_list: List[torch.FloatTensor] = [] + global_timestep = 0.0 + velocity: float = 90.0 + pitch: float = 64.0 + + # Convert int tensor to intermediary string token representation + string_tokens: List[str] = [tokens_to_midi_vocab[token.item()] for token in tokens] + + for idx, message in enumerate(string_tokens): + + # For 'pitch' tokens, we need to ensure that the following two tokens + # are velocity and duration (in either order). Otherwise the model + # has not made a sequentially correct prediction and we will skip to the + # next token + if "Pitch" in message and idx != len(string_tokens) - 2: + + # Check if velocity and duration are present in the next 2 tokens, regardless of order + if all(any(keyword in s for s in string_tokens[idx + 1:idx + 3]) for keyword in ("Velocity", "Duration")): + matching_tokens = extract_matching_strings(string_tokens[idx + 1:idx + 3], + ["Velocity", "Duration"]) + vel_tok, dur_tok = matching_tokens[0], matching_tokens[1] + + pitch = float(message.split("_")[1]) + velocity = float(vel_tok.split("_")[1]) + duration = convert_miditok_timing_to_float(dur_tok) + + midi_tuples_tensor_list.append(torch.FloatTensor([[0.0, + pitch, + velocity, + global_timestep]])) + + midi_tuples_tensor_list.append(torch.FloatTensor([[1.0, + pitch, + velocity, + (global_timestep + duration)]])) + + elif "TimeShift" in message: + global_timestep += calculate_delta(message) + + midi_output_tensor: torch.FloatTensor = torch.cat(midi_tuples_tensor_list, 0) + + return midi_output_tensor + + +# ------- +# REMI + + +def convert_midi_to_remi_tokens(midi_data: torch.Tensor, + vocab: Dict[str, int], + tokenizer_data: TokenData) -> torch.Tensor: + token_strings: List[str] = ["Bar_None"] + global_timestep: float = 0.0 + new_bar: bool = False + note_off_location: float = 0.0 + + for idx, message in enumerate(midi_data): + + pitch = int(message[1].item()) + pitch_quantized = closest_int(tokenizer_data.ints["pitch_values"], pitch) + time = float(message[3].item()) + + if message[0] == 0: + + if time > global_timestep: + delta = time - global_timestep + + # Todo: Remove hard-coded 4/4 timing. 
But Miditok only supports 4/4 as of June 2023 + # Deal with bar tokens + while delta >= 4.0: + token_strings.append("Bar_None") + global_timestep += 4.0 - global_timestep % 4.0 + delta = time - global_timestep + new_bar = True + + if delta > (min(tokenizer_data.floats["position_floats"]) * 0.5) or new_bar: + closest_idx = closest_float_idx(tokenizer_data.floats["position_floats"], delta) + quantized_delta = tokenizer_data.floats["position_floats"][closest_idx] + token_strings.append(tokenizer_data.strings["position_strings"][closest_idx]) + global_timestep += quantized_delta + + velocity = int(message[2].item()) + velocity = closest_int(tokenizer_data.ints["velocity_values"], velocity) + token_strings.append(f"Pitch_{pitch_quantized}") + token_strings.append(f"Velocity_{velocity}") + + note_off_location = find_next_note_off_location(midi_data[idx:], note_value=pitch) + delta = note_off_location - global_timestep + closest_idx = closest_float_idx(tokenizer_data.floats["duration_floats"], delta) + token_strings.append(tokenizer_data.strings["duration_strings"][closest_idx]) + + tokens = [vocab[token] for token in token_strings] + tokens = torch.tensor(tokens) + + return tokens + + +def convert_remi_tokens_to_midi(tokens: torch.Tensor, + tokens_to_midi_vocab: Dict[int, str], + position_granularity: float) -> torch.FloatTensor: + midi_tuples_tensor_list: List[torch.FloatTensor] = [] + global_timestep = 0.0 + delta: float = 0.0 + velocity: float = 90.0 + pitch: float = 64.0 + + # Convert int tensor to intermediary string token representation + string_tokens: List[str] = [tokens_to_midi_vocab[token.item()] for token in tokens] + + # First "Bar_None" token can be removed. + string_tokens = string_tokens[1:] if string_tokens[0] == "Bar_None" else string_tokens + + for idx, message in enumerate(string_tokens): + + if message == "Bar_None": + global_timestep += 4.0 - global_timestep % 4.0 + + if "Position" in message: + global_timestep += float(message.split("_")[1]) * position_granularity + + if "Pitch" in message and idx != len(string_tokens) - 2: + + if all(any(keyword in s for s in string_tokens[idx + 1:idx + 3]) for keyword in ("Velocity", "Duration")): + matching_tokens = extract_matching_strings(string_tokens[idx + 1:idx + 3], + ["Velocity", "Duration"]) + vel_tok, dur_tok = matching_tokens[0], matching_tokens[1] + + pitch = float(message.split("_")[1]) + velocity = float(vel_tok.split("_")[1]) + duration = convert_miditok_timing_to_float(dur_tok) + + midi_tuples_tensor_list.append(torch.FloatTensor([[0.0, + pitch, + velocity, + global_timestep]])) + + midi_tuples_tensor_list.append(torch.FloatTensor([[1.0, + pitch, + velocity, + (global_timestep + duration)]])) + + midi_output_tensor: torch.FloatTensor = torch.cat(midi_tuples_tensor_list, 0) + + return midi_output_tensor + + +def convert_midi_to_hvo(midi_data: torch.Tensor) -> torch.Tensor: + # Determine total number of 2-bar patterns based on the highest time value in midi_data tensor + mask = (midi_data[:, 0] == 0.0) + num_patterns = int(torch.max(midi_data[mask, 3]) / 8) + 1 + hvo_tensor = torch.zeros((num_patterns, 32, 27)) + + for idx, message in enumerate(midi_data): + if float(message[0].item()) == 0: + time = float(message[3].item()) + hit_location = int(round(time / 0.25) % 32) + pattern = int(time / 8) + velocity = float(message[2].item() / 127.0) + + # Check if the velocity is higher than previous input on this timestep + # TODO: why is this indexed at 2, 11, 20? + # TODO: is this checking the previous input on this timestep? 
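+            # Note (assumption): in the 27-wide HVO layout, columns 0-8 are hits, 9-17 velocities,
+            # and 18-26 offsets, so indices 2/11/20 are the hit/velocity/offset of the third voice
+            # (closed hi-hat, note 42, in the Roland mapping used by convert_hvo_to_midi below).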
+            if velocity > float(hvo_tensor[pattern, hit_location, 11].item()):
+                offset = (time - (hit_location * 0.25)) / 0.125
+                hvo_tensor[pattern, hit_location, 2] = 1.0
+                hvo_tensor[pattern, hit_location, 11] = velocity
+                hvo_tensor[pattern, hit_location, 20] = offset
+
+    return hvo_tensor
+
+def convert_midi_to_monophonic_hvo(midi_data: torch.Tensor) -> torch.Tensor:
+    # Determine total number of 2-bar patterns as determined by the highest time value in midi_data tensor
+    mask = (midi_data[:, 0] == 0.0)
+    num_patterns = int(torch.max(midi_data[mask, 3]) / 8) + 1
+    hvo_tensor = torch.zeros((num_patterns, 32, 3))
+
+    for idx, message in enumerate(midi_data):
+        if float(message[0].item()) == 0:
+            time = float(message[3].item())
+            hit_location = int(round(time / 0.25) % 32)
+            pattern = int(time / 8)
+            velocity = float(message[2].item() / 127.0)
+
+            # Check if the velocity is higher than previous input on this timestep
+            # TODO: Do we need a check here?
+            # if velocity > float(hvo_tensor[pattern, hit_location - 1, 1].item()):
+            offset = (time - (hit_location * 0.25)) / 0.125
+            hvo_tensor[pattern, hit_location, 0] = 1.0
+            hvo_tensor[pattern, hit_location, 1] = velocity
+            hvo_tensor[pattern, hit_location, 2] = offset
+
+    return hvo_tensor
+
+
+def convert_hvo_to_midi(hvo: torch.Tensor) -> torch.Tensor:
+    midi_tuples_tensor_list: List[torch.FloatTensor] = []
+    roland_mapping = [36, 38, 42, 46, 43, 47, 50, 49, 51]
+
+    # Input will be (x, 32, 27) where 'x' is the number of 2-bar patterns
+    for pattern_idx, two_bar_sequence in enumerate(hvo):
+        for beat_idx, step in enumerate(two_bar_sequence):
+            for note_idx, note in enumerate(step[:9]):
+                if note.item() >= 0.9:
+                    pitch = float(roland_mapping[note_idx])
+                    velocity = float(two_bar_sequence[beat_idx, (note_idx + 9)].item() * 127)
+                    offset = float(two_bar_sequence[beat_idx, (note_idx + 18)].item() * 0.125)
+                    time = float(beat_idx * 0.25) + offset
+                    time = time if time >= 0.0 else 0.0
+                    time += float(pattern_idx * 8.0)
+
+                    # Note on
+                    midi_tuples_tensor_list.append(torch.FloatTensor([[0.0,
+                                                                       pitch,
+                                                                       velocity,
+                                                                       time]]))
+                    # Note off
+                    midi_tuples_tensor_list.append(torch.FloatTensor([[1.0,
+                                                                       pitch,
+                                                                       90.0,
+                                                                       (time + 0.15)]]))
+    if midi_tuples_tensor_list:
+        midi_output_tensor: torch.Tensor = torch.cat(midi_tuples_tensor_list, 0)
+        _, indices = torch.sort(midi_output_tensor[:, 3], descending=False)
+        midi_output_tensor = midi_output_tensor[indices]
+    else:
+        midi_output_tensor: torch.Tensor = torch.zeros((2, 4))  # empty output; matches the (x, 4) MIDI format
+
+    return midi_output_tensor
diff --git a/neutone_midi_sdk/training_guides/model_preparation_guide.md b/neutone_midi_sdk/training_guides/model_preparation_guide.md
new file mode 100644
index 0000000..98c6295
--- /dev/null
+++ b/neutone_midi_sdk/training_guides/model_preparation_guide.md
@@ -0,0 +1,83 @@
+# Preparing your model
+
+This guide assumes that you have already trained a model for a symbolic task, on a tokenization method that
+is supported by the SDK (or intend to add your own custom tokenization script).
+
+Serializing a model into [Torchscript](https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html)
+allows the operations to be replicated in a C++ environment,
+such as those used in Digital Audio Workstations. There are two main methods to accomplish this:
+- [Scripting](https://pytorch.org/docs/stable/generated/torch.jit.script.html)
+- [Tracing](https://pytorch.org/docs/stable/generated/torch.jit.trace.html)
+
+Scripting is the preferable option when possible, as it is more robust to various architectures.
+However, in some circumstances tracing is the only option. So far we have found that
+[HuggingFace](https://huggingface.co/docs/transformers/v4.17.0/en/serialization) models only
+support tracing.
+
+### Scripting a model
+In case the entire functionality of your model is encoded in the forward() function:
+```python
+trained_model = MyModel(init_args)  # trained torch.nn.Module
+scripted_model = torch.jit.script(trained_model)
+torch.jit.save(scripted_model, "filename.pt")
+```
+
+You can combine multiple models / functionalities by combining them into a single forward
+function of a new meta-model, and then scripting it. This is particularly useful when your
+model has a sampling process that is separate from the forward() function.
+
+```python
+class Sampler(torch.nn.Module):
+    def __init__(self, args):
+        super().__init__()
+
+    def forward(self, x):
+        # Here you can specify any operations needed for sampling from the output of the model
+        y = x + 1
+        return y
+
+class FullModel(torch.nn.Module):
+    def __init__(self, trained_model, sampler):
+        super().__init__()
+        self.model = trained_model
+        self.sampler = sampler
+
+    def forward(self, x):
+        # The full process occurs here
+        logits = self.model(x)
+        output = self.sampler(logits)
+        return output
+
+# Create the model
+trained_model = MyModel(init_args)
+sampler = Sampler(init_args)
+full_model = FullModel(trained_model, sampler)
+
+# Now you can script it all together
+scripted_model = torch.jit.script(full_model)
+torch.jit.save(scripted_model, "filename.pt")
+```
+
+### Tracing a model
+
+Below is an example of how to trace a HuggingFace GPT-2 model:
+
+```python
+import json
+import os
+
+import torch
+from transformers import GPT2LMHeadModel
+
+with open(os.path.join(train_path, "config.json")) as fp:
+    config = json.load(fp)
+vocab_size = config["vocab_size"]
+dummy_input = torch.randint(0, vocab_size, (1, 2048))
+partial_model = GPT2LMHeadModel.from_pretrained(train_path, torchscript=True)
+traced_model = torch.jit.trace(partial_model, example_inputs=dummy_input)
+torch.jit.save(traced_model, "traced_model.pt")
+```
+
+Notably, you can combine a traced module with other components and then script the result. This is helpful
+in the above case, as the 'generate' function requires dynamic processes that cannot be captured with
+tracing. Using the combination method detailed above, you can load this traced module alongside a custom
+generate/sample function, and then script them all together.
+
+To be clear, we suggest scripting a model whenever possible. Tracing records only the exact set
+of operations performed on the dummy input, so the likelihood of missing important parts
+of the model's functionality is much higher when tracing.
diff --git a/neutone_midi_sdk/training_guides/model_training_guide.md b/neutone_midi_sdk/training_guides/model_training_guide.md
new file mode 100644
index 0000000..6f92659
--- /dev/null
+++ b/neutone_midi_sdk/training_guides/model_training_guide.md
@@ -0,0 +1,89 @@
+# Model Training Guide
+
+This guide assumes you have already chosen a dataset of MIDI files and implemented a
+PyTorch model architecture.
+
+### Tokenization
+It's very important that the model is trained on a tokenization method compatible with the
+latest version of the Neutone-MIDI SDK:
+
+- MIDILike
+- TSD
+- REMI
+- HVO
+
+[Click here](https://miditok.readthedocs.io/en/latest/tokenizations.html) for full documentation
+on these methods.
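+
+For instance, MIDITok's `beat_res` setting maps beat ranges to a sampling resolution. A sketch of how to
+read one (the values here are illustrative, matching the example script later in this guide):
+
+```python
+# {(start_beat, end_beat): samples_per_beat}
+beat_res = {(0, 4): 8, (4, 12): 4}  # 32nd notes up to beat 4, 16th notes from beats 4 to 12
+```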
+
+### Setting Parameters
+Each tokenizer has its own individual parameter settings, which are detailed in the link above.
+A key functionality of our SDK is that it is robust to all values of these settings, so you can fine-tune them exactly
+to the needs of your model. For example, your 'pitch-range' could be very small if you are making a drum model.
+
+**Important**: When setting up the tokenizer, there is an option for 'additional tokens'
+and 'special tokens'. While you can implement them for your training pipeline, we currently
+do not translate any of them into actual MIDI data - e.g. a 'Chord' token will not actually
+produce a chord. For this reason we generally recommend leaving them at their defaults.
+
+### Saving your Settings
+After tokenizing the dataset, it's important to save both the config and vocab files
+in JSON format. They are used by the "Data_Preparation" pipeline to extract the information necessary
+for MIDI translation.
+
+Here is an example of tokenizing with MIDILike (more examples on the Miditok doc page):
+
+```python
+import argparse
+import os
+import shutil
+import json
+from pathlib import Path
+from miditok import MIDILike
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--midi_path", type=str, help="Path to MIDI Files")
+    args = parser.parse_args()
+    MIDI_PATH = args.midi_path
+
+    # Parameters
+    pitch_range = range(21, 109)
+    beat_res = {(0, 4): 8, (4, 12): 4}
+    nb_velocities = 32
+    additional_tokens = {'Chord': False,
+                         'Rest': False,
+                         'Tempo': False,
+                         'Program': False,
+                         'TimeSignature': False,
+                         'rest_range': (2, 8),  # (half, 8 beats)
+                         'nb_tempos': 32,  # nb of tempo bins
+                         'tempo_range': (40, 250)}  # (min, max)
+    special_tokens = ["PAD", "BOS", "EOS"]
+
+    # Create the tokenizer and convert MIDIs to tokens
+    print("#---- Tokenizing the data")
+    tokens_path = Path('tokenized_data')
+
+    # Check if the directory exists
+    if tokens_path.exists() and tokens_path.is_dir():
+        shutil.rmtree(tokens_path)
+    os.makedirs(tokens_path)
+
+    tokenizer = MIDILike(pitch_range, beat_res, nb_velocities, additional_tokens, special_tokens=special_tokens)
+    midi_paths = list(Path(MIDI_PATH).glob('**/*.mid')) + list(Path(MIDI_PATH).glob('**/*.midi'))
+    print(f"Training on {len(midi_paths)} MIDI files.\n")
+    tokenizer.tokenize_midi_dataset(midi_paths, tokens_path)
+
+    # Save tokenization settings
+    tokenizer_params = Path('tokenizer_params')
+    if not os.path.exists(tokenizer_params):
+        os.makedirs(tokenizer_params)
+
+    with open(os.path.join(tokenizer_params, "vocab.json"), "w") as fp:
+        json.dump(tokenizer.vocab, fp)
+    tokenizer.save_params(out_path="tokenizer_params/config.json")
+```
+
+Once your data is tokenized, it is time to train! As long as your model is a PyTorch model that can be scripted or traced (detailed
+in the following guide, 'model_preparation_guide'), you are good to go. Happy training!
\ No newline at end of file