~~~
import json
import os
import sys
import tempfile

import numpy
import pandas
import tqdm
from pytorch_lightning import LightningDataModule, LightningModule
import torch
import torchaudio
from torchvision.transforms import Compose
from torch.utils.data import Dataset, DataLoader
from nemo.collections.asr.data.audio_to_text import _AudioTextDataset
from nemo.collections.asr.models import EncDecCTCModel
import nemo_old
import nemo_old.vocabs
import torchcrepe

nemo_old.disable_strict()


class TacotronDataset(Dataset):
    def __init__(self, source, filelist_fn, transform=None):
        """
        :param source: Path to the dataset.
        :param filelist_fn: File path within source containing audio|transcript lines.
        :param transform: Optional callable applied to each sample dict before it is returned.
        """
        self.source = source
        filelist_path = os.path.join(source, filelist_fn)
        self.samples = pandas.read_csv(filelist_path, header=None, sep="|")
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        source_fn = self.samples.iloc[index, 0]
        transcript = self.samples.iloc[index, 1]
        source_path = os.path.join(self.source, source_fn)
        audio, rate = torchaudio.load(source_path)
        duration = audio.shape[-1] / rate
        sample = {
            "audio": audio,
            "sample_rate": rate,
            "duration": duration,
            "audio_path": os.path.normpath(source_path),
            "transcript": transcript,
        }
        if self.transform:
            sample = self.transform(sample)
        return sample


class NormalizeAudio:
    def __init__(
        self,
        resample_rate=22050,
        lowpass_filter_width=6,
        rolloff_frequency=0.99,
        resampling_method="sinc_interpolation",
        kaiser_window_beta=None,
    ):
        self.resample_rate = resample_rate
        self.lowpass_filter_width = lowpass_filter_width
        self.rolloff_frequency = rolloff_frequency
        self.resampling_method = resampling_method
        self.kaiser_window_beta = kaiser_window_beta

    def __call__(self, sample):
        # Fall back to the original rate if no resample rate was given.
        new_rate = self.resample_rate or sample["sample_rate"]
        audio = torchaudio.functional.resample(
            sample["audio"],
            sample["sample_rate"],
            new_rate,
            self.lowpass_filter_width,
            self.rolloff_frequency,
            self.resampling_method,
            self.kaiser_window_beta,
        )
        sample = dict(sample)
        sample["audio"] = audio
        sample["sample_rate"] = new_rate
        return sample


class AlignAudio:
    def __init__(self):
        super().__init__()
        self.aligner = EncDecCTCModel.from_pretrained(model_name="asr_talknet_aligner")
        self.BLANK_TOKEN = self.aligner.decoder.num_classes_with_blank - 1
        self.LABELS = self.aligner.decoder.vocabulary
        self.parser = make_vocab(
            notation="phonemes",
            punct=True,
            spaces=True,
            stresses=False,
            add_blank_at="last",
        )

    def __call__(self, sample):
        # The aligner expects a batch of raw waveforms [B, T] and their lengths.
        audio = sample["audio"].reshape(1, -1).to(self.aligner.device)
        length = torch.tensor([audio.shape[-1]], device=self.aligner.device)
        log_probs, _, greedy_predictions = self.aligner(
            input_signal=audio, input_signal_length=length
        )
        return log_probs


def make_vocab(
    notation="chars",
    punct=True,
    spaces=False,
    stresses=False,
    add_blank_at="last_but_one",
):
    """Stolen from Nvidia NeMo. Constructs vocabulary from given parameters.

    Args:
        notation (str): Either 'chars' or 'phonemes' as general notation.
        punct (bool): True if reserve grapheme for basic punctuation.
        spaces (bool): True if prepend spaces to every punctuation symbol.
        stresses (bool): True if use phonemes codes with stresses (0-2).
        add_blank_at: add blank to labels in the specified order ("last" or
            "last_but_one"), if None then no blank in labels.

    Returns:
        (vocabs.Base) Vocabulary
    """
    if notation == "chars":
        vocab = nemo_old.vocabs.Chars(
            punct=punct, spaces=spaces, add_blank_at=add_blank_at
        )
    elif notation == "phonemes":
        vocab = nemo_old.vocabs.Phonemes(
            punct=punct,
            stresses=stresses,
            spaces=spaces,
            add_blank_at=add_blank_at,
        )
    else:
        raise ValueError("Unsupported vocab type.")
    return vocab


class PhonemeDurationDataset(Dataset):
    def __init__(self, source, filelists):
        self.source = source
        self.audio_datasets = [TacotronDataset(self.source, x) for x in filelists]
        self.aligner = None
        self.parser = None
        self.BLANK_TOKEN = None
        self.LABELS = None
        self.dataset = None
        self.samples = None

    def load(self):
        self.aligner = EncDecCTCModel.from_pretrained(model_name="asr_talknet_aligner")
        self.parser = make_vocab(
            notation="phonemes",
            punct=True,
            spaces=True,
            stresses=False,
            add_blank_at="last",
        )
        self.BLANK_TOKEN = self.aligner.decoder.num_classes_with_blank - 1
        self.LABELS = self.aligner.decoder.vocabulary
        with tempfile.NamedTemporaryFile(mode="w", encoding="utf8") as manifest_fp:
            self.write_manifest(manifest_fp)
            self.dataset = _AudioTextDataset(
                manifest_filepath=manifest_fp.name,
                sample_rate=22050,
                parser=self.parser,
            )
            loader = torch.utils.data.DataLoader(
                dataset=self.dataset,
                batch_size=1,
                collate_fn=self.dataset.collate_fn,
                shuffle=False,
            )
            self.samples = list(loader)

    def unload(self):
        del self.aligner
        del self.parser
        del self.dataset
        del self.samples

    def write_manifest(self, manifest_fp):
        for audio_dataset in self.audio_datasets:
            for sample in audio_dataset:
                transcript = sample["transcript"]
                audio_path = sample["audio_path"]
                duration = sample["audio"].shape[-1] / sample["sample_rate"]
                manifest_fp.write(
                    json.dumps(
                        {
                            "audio_filepath": os.path.normpath(audio_path),
                            "duration": duration,
                            "text": transcript,
                        }
                    )
                )
                manifest_fp.write("\n")
        manifest_fp.flush()

    def __getitem__(self, i):
        sample = self.samples[i]
        log_probs, _, _ = self.aligner(
            input_signal=sample[0].to(self.aligner.device),
            input_signal_length=sample[1].to(self.aligner.device),
        )
        log_probs = log_probs.cpu().detach().numpy()[0]
        seq_ids = sample[2][0].cpu().detach().numpy()
        target_tokens = self.preprocess_tokens(seq_ids)
        f, p = self.forward_extractor(target_tokens, log_probs, self.BLANK_TOKEN)
        durs = self.backward_extractor(f, p)
        dur_key = os.path.normpath(
            self.dataset.manifest_processor.collection[i].audio_file
        )
        dur_data = {
            "blanks": torch.tensor(durs[::2], dtype=torch.long).cpu().detach(),
            "tokens": torch.tensor(durs[1::2], dtype=torch.long).cpu().detach(),
        }
        return dur_key, dur_data

    def __len__(self):
        return sum([len(x) for x in self.audio_datasets])

    def preprocess_tokens(self, tokens):
        # Interleave CTC blanks between (and around) the target tokens.
        new_tokens = [self.BLANK_TOKEN]
        for c in tokens:
            new_tokens.extend([c, self.BLANK_TOKEN])
        tokens = new_tokens
        return tokens

    def forward_extractor(self, tokens, log_probs, blank):
        """Computes states f and p."""
        n, m = len(tokens), log_probs.shape[0]
        # `f[s, t]` -- max sum of log probs for `s` first codes
        # with `t` first timesteps with ending in `tokens[s]`.
        f = numpy.empty((n + 1, m + 1), dtype=float)
        f.fill(-(10**9))
        p = numpy.empty((n + 1, m + 1), dtype=int)
        f[0, 0] = 0.0  # Start
        for s in range(1, n + 1):
            c = tokens[s - 1]
            for t in range((s + 1) // 2, m + 1):
                f[s, t] = log_probs[t - 1, c]
                # Option #1: prev char is equal to current one.
                if s == 1 or c == blank or c == tokens[s - 3]:
                    options = f[s : (s - 2 if s > 1 else None) : -1, t - 1]
                else:  # Is not equal to current one.
                    options = f[s : (s - 3 if s > 2 else None) : -1, t - 1]
                f[s, t] += numpy.max(options)
                p[s, t] = numpy.argmax(options)
        return f, p

    def backward_extractor(self, f, p):
        """Computes durs from f and p."""
        n, m = f.shape
        n -= 1
        m -= 1
        durs = numpy.zeros(n, dtype=int)
        if f[-1, -1] >= f[-2, -1]:
            s, t = n, m
        else:
            s, t = n - 1, m
        while s > 0:
            durs[s - 1] += 1
            s -= p[s, t]
            t -= 1
        assert durs.shape[0] == n
        assert numpy.sum(durs) == m
        assert numpy.all(durs[1::2] > 0)
        return durs


def merge_audio_channels(sample):
    sample = dict(sample)
    sample["audio"] = sample["audio"].mean(0)
    return sample


class PhonemeDurationAug:
    def __init__(self, phonemeKv):
        self.durations = phonemeKv

    def __call__(self, sample):
        key = sample["audio_path"]
        # Interleave blank and token durations into a single [1, N] tensor.
        durations = (
            torch.stack(
                (
                    self.durations[key]["blanks"],
                    torch.cat(
                        (
                            self.durations[key]["tokens"],
                            torch.zeros(1, dtype=torch.long),
                        )
                    ),
                ),
                dim=1,
            )
            .view(-1)[:-1]
            .view(1, -1)
        )
        sample = dict(sample)
        sample["durations"] = durations
        return sample


class PitchDataset(Dataset):
    def __init__(self, source, filelists, pitch_hop_length=256):
        self.hop_length = pitch_hop_length
        self.source = source
        self.filelists = filelists
        self.audio_datasets = [
            TacotronDataset(self.source, x, transform=merge_audio_channels)
            for x in filelists
        ]
        self.sample_counts = [len(x) for x in self.audio_datasets]

    def __len__(self):
        return sum(self.sample_counts)

    def __getitem__(self, index):
        # Figure out which dataset contains the index.
        dataset = 0
        print("getting index", index, "of", len(self))
        while index >= len(self.audio_datasets[dataset]):
            index -= len(self.audio_datasets[dataset])
            dataset += 1
        sample = self.audio_datasets[dataset][index]
        audio_path = sample["audio_path"]
        audio = sample["audio"]
        sample_rate = sample["sample_rate"]
        pitch = crepe_f0((self.hop_length, audio_path, audio, sample_rate))
        return audio_path, pitch


class PitchAug:
    def __init__(self, pitchKv) -> None:
        self.pitches = pitchKv

    def __call__(self, sample):
        audio_path = sample["audio_path"]
        sample = dict(sample)
        sample["pitch"] = self.pitches[audio_path]
        return sample


def crepe_f0(args):
    hop_length, audio_path, audio, sample_rate = args
    audio = audio.reshape(1, *audio.shape)
    print("processing", audio_path, sample_rate)
    frequency, confidence = torchcrepe.predict(
        audio,
        sample_rate,
        hop_length,
        return_periodicity=True,
        fmin=100,
        fmax=600,
        device="cuda",
    )
    # Frame timestamps, assuming 100 pitch frames per second.
    time = numpy.arange(confidence.shape[-1]) / 100.0
    audio = audio.cpu()[0].numpy() * 2**15
    audio_x = numpy.arange(0, len(audio)) / sample_rate
    frequency = frequency[0].cpu().numpy()
    confidence = confidence[0].cpu().numpy()
    x = numpy.arange(0, len(audio), hop_length) / sample_rate
    freq_interp = numpy.interp(x, time, frequency)
    conf_interp = numpy.interp(x, time, confidence)
    audio_interp = numpy.interp(x, audio_x, numpy.absolute(audio)) / 2**15
    weights = [0.5, 0.25, 0.25]
    audio_smooth = numpy.convolve(audio_interp, numpy.array(weights)[::-1], "same")
    conf_threshold = 0.25
    audio_threshold = 0.0005
    # Zero out frames where the tracker is unsure or the signal is near-silent.
    for i in range(len(freq_interp)):
        if conf_interp[i] < conf_threshold:
            freq_interp[i] = 0.0
        if audio_smooth[i] < audio_threshold:
            freq_interp[i] = 0.0
    # Hack to make f0 and mel lengths equal.
    if len(audio) % hop_length == 0:
        freq_interp = numpy.pad(freq_interp, pad_width=[0, 1])
    return torch.from_numpy(freq_interp.astype(numpy.float32))


class TalkNetDataLoader(LightningDataModule):
    def __init__(
        self,
        source,
        pitchKv,
        phonemeKv,
        batch_size=1,
        train_filelist="train_filelist.txt",
        val_filelist="val_filelist.txt",
        *args,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.source = source
        self.train_filelist = train_filelist
        self.val_filelist = val_filelist
        self.batch_size = batch_size
        # Pitch statistics over voiced frames (> 20 Hz) across the whole dataset.
        all_pitches = torch.cat([x[x > 20] for x in pitchKv.values()])
        self.pitch_mean = all_pitches.mean().item()
        self.pitch_std = all_pitches.std().item()
        self.transform = Compose(
            [
                merge_audio_channels,
                PitchAug(pitchKv),
                PhonemeDurationAug(phonemeKv),
                NormalizeAudio(resample_rate=22050),
            ]
        )

    def train_dataloader(self):
        return DataLoader(
            TacotronDataset(self.source, self.train_filelist, transform=self.transform),
            batch_size=self.batch_size,
        )

    def val_dataloader(self):
        return DataLoader(
            TacotronDataset(self.source, self.val_filelist, transform=self.transform),
            batch_size=self.batch_size,
        )

    def test_dataloader(self):
        return self.val_dataloader()


class Cached:
    def __init__(self, dataset, cache_path):
        self.dataset = dataset
        self.cache_path = cache_path
        # Datasets without a load() method need no lazy initialisation.
        self.loaded = not hasattr(self.dataset, "load")
        os.makedirs(self.cache_path, exist_ok=True)

    def dict(self):
        result = {}
        for i in range(len(self.dataset)):
            cache_path = os.path.join(self.cache_path, f"{i}.pt")
            if os.path.exists(cache_path):
                key, value = torch.load(cache_path)
            else:
                if not self.loaded:
                    self.dataset.load()
                    self.loaded = True
                key, value = self.dataset[i]
                torch.save((key, value), cache_path)
            result[key] = value
        if self.loaded and hasattr(self.dataset, "unload"):
            self.dataset.unload()
        return result


def load_talknet_data(
    source, train_filelist="train_filelist.txt", val_filelist="val_filelist.txt"
):
    pitchKv = Cached(
        PitchDataset(source, [train_filelist, val_filelist]),
        "./pitch-data-cache",
    ).dict()
    phonemeKv = Cached(
        PhonemeDurationDataset(source, [train_filelist, val_filelist]),
        "./phoneme-data-cache",
    ).dict()
    return TalkNetDataLoader(
        source,
        pitchKv,
        phonemeKv,
        train_filelist=train_filelist,
        val_filelist=val_filelist,
    )
~~~
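A minimal usage sketch of the module above. The import path `talknet_data` is hypothetical, and the first call assumes a CUDA device (torchcrepe runs with `device="cuda"`) plus network access to download the pretrained `asr_talknet_aligner` checkpoint; later runs reuse the on-disk caches.

~~~
# Hypothetical module name for the file above.
from talknet_data import load_talknet_data

# First call populates ./pitch-data-cache and ./phoneme-data-cache;
# subsequent calls read the cached tensors instead of recomputing them.
data = load_talknet_data(
    "./sunset-singing",
    train_filelist="train_filelist.txt",
    val_filelist="val_filelist.txt",
)

# Each batch is a dict of collated tensors (batch_size=1): resampled audio,
# per-frame pitch, and interleaved blank/token durations.
batch = next(iter(data.train_dataloader()))
print(batch["audio"].shape, batch["pitch"].shape, batch["durations"].shape)
print("pitch stats:", data.pitch_mean, data.pitch_std)
~~~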