1747 15.81 KB 516
import csv
import json
import os
import sys
import tempfile

import numpy
import pandas
import tqdm
from pytorch_lightning import LightningDataModule, LightningModule
import torch
import torchaudio
from torchvision.transforms import Compose
from torch.utils.data import Dataset, DataLoader
from nemo.collections.asr.data.audio_to_text import _AudioTextDataset
from nemo.collections.asr.models import EncDecCTCModel
import nemo_old
import nemo_old.vocabs
import torchcrepe
# NOTE(review): module-level side effect executed at import time; presumably
# relaxes a strictness check in the legacy `nemo_old` package — its exact
# semantics are not visible in this file.
nemo_old.disable_strict()
class TacotronDataset(Dataset):
    """Dataset over a Tacotron-style ``audio_path|transcript`` filelist.

    Each item is a dict with keys ``audio`` (tensor from ``torchaudio.load``),
    ``sample_rate``, ``duration`` (seconds), ``audio_path`` (normalized) and
    ``transcript``. If ``transform`` is given, it is applied to that dict.
    """

    def __init__(self, source, filelist_fn, transform=None):
        """
        :param source: Path to the dataset directory.
        :param filelist_fn: File path within source containing audio|transcript lines.
        :param transform: Optional callable applied to each sample dict (keys as
            listed in the class docstring); it must return a sample dict.
        """
        self.source = source
        filelist_path = os.path.join(source, filelist_fn)
        # Read every field verbatim as text: with pandas' defaults, transcripts
        # such as "NA", "null" or "None" become NaN, double quotes are treated
        # as CSV quoting, and purely numeric filenames are parsed as ints.
        self.samples = pandas.read_csv(
            filelist_path,
            header=None,
            sep="|",
            dtype=str,
            keep_default_na=False,
            quoting=csv.QUOTE_NONE,
        )
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        source_fn = self.samples.iloc[index, 0]
        transcript = self.samples.iloc[index, 1]
        source_path = os.path.join(self.source, source_fn)
        audio, rate = torchaudio.load(source_path)
        sample = {
            "audio": audio,
            "sample_rate": rate,
            "duration": audio.shape[-1] / rate,
            "audio_path": os.path.normpath(source_path),
            "transcript": transcript,
        }
        if self.transform:
            sample = self.transform(sample)
        return sample
class NormalizeAudio:
    """Sample-dict transform that resamples ``audio`` to a fixed rate.

    Returns a shallow copy of the sample with ``audio`` resampled through
    ``torchaudio.functional.resample`` and ``sample_rate`` updated to match.
    """

    def __init__(
        self,
        resample_rate=22050,
        lowpass_filter_width=6,
        rolloff_frequency=0.99,
        resampling_method="sinc_interpolation",
        kaiser_window_beta=None,
    ):
        """
        :param resample_rate: Target sample rate in Hz. A falsy value (e.g. None)
            keeps each sample's own rate.
        :param lowpass_filter_width: Passed to torchaudio's resampler.
        :param rolloff_frequency: Passed as torchaudio's ``rolloff``.
        :param resampling_method: Passed to torchaudio's resampler.
        :param kaiser_window_beta: Passed as torchaudio's ``beta``.
        """
        self.resample_rate = resample_rate
        self.lowpass_filter_width = lowpass_filter_width
        self.rolloff_frequency = rolloff_frequency
        self.resampling_method = resampling_method
        self.kaiser_window_beta = kaiser_window_beta

    def __call__(self, sample):
        orig_rate = sample["sample_rate"]
        # A falsy resample_rate means "keep the input rate"; use the sample's
        # rate both for resampling and for the recorded sample_rate.
        target_rate = self.resample_rate or orig_rate
        audio = torchaudio.functional.resample(
            sample["audio"],
            orig_rate,
            target_rate,
            self.lowpass_filter_width,
            self.rolloff_frequency,
            self.resampling_method,
            self.kaiser_window_beta,
        )
        sample = dict(sample)
        sample["audio"] = audio
        sample["sample_rate"] = target_rate
        return sample
class AlignAudio:
    """Callable that runs the pretrained ``asr_talknet_aligner`` CTC model on a
    sample's audio and returns the model's log-probabilities."""

    def __init__(self):
        super().__init__()
        self.aligner = EncDecCTCModel.from_pretrained(model_name="asr_talknet_aligner")
        self.BLANK_TOKEN = self.aligner.decoder.num_classes_with_blank - 1
        # Decoder label set. ``LABELS`` mirrors PhonemeDurationDataset's naming;
        # ``LABEL`` is kept for existing callers.
        self.LABEL = self.aligner.decoder.vocabulary
        self.LABELS = self.LABEL
        self.parser = make_vocab(
            notation="phonemes",
            punct=True,
            spaces=True,
            stresses=False,
            add_blank_at="last",
        )

    def __call__(self, sample):
        """
        :param sample: Sample dict with an ``audio`` tensor. A 1-D tensor (e.g.
            after ``merge_audio_channels``) is treated as a single clip; a 2-D
            tensor is assumed to already be (batch, time) — TODO confirm for
            callers passing un-merged (channels, time) audio.
        :return: The aligner's log-probabilities.
        """
        audio = sample["audio"]
        if audio.dim() == 1:
            audio = audio.unsqueeze(0)
        device = self.aligner.device
        # Every row in the batch spans the full time axis.
        length = torch.full((audio.shape[0],), audio.shape[-1], dtype=torch.long)
        log_probs, _, _ = self.aligner(
            input_signal=audio.to(device),
            input_signal_length=length.to(device),
        )
        return log_probs
def make_vocab(
    notation="chars",
    punct=True,
    spaces=False,
    stresses=False,
    add_blank_at="last_but_one",
):
    """Build a ``nemo_old`` vocabulary (adapted from NVIDIA NeMo).

    Args:
        notation (str): ``"chars"`` for graphemes or ``"phonemes"`` for phonemes.
        punct (bool): Reserve symbols for basic punctuation.
        spaces (bool): Prepend a space to every punctuation symbol.
        stresses (bool): Use stress-marked phoneme codes (0-2); phonemes only.
        add_blank_at: Position of the blank label ("last" or "last_but_one");
            None means no blank label.

    Returns:
        The constructed ``nemo_old.vocabs`` vocabulary object.

    Raises:
        ValueError: If ``notation`` is neither "chars" nor "phonemes".
    """
    if notation not in ("chars", "phonemes"):
        raise ValueError("Unsupported vocab type.")
    shared_kwargs = dict(punct=punct, spaces=spaces, add_blank_at=add_blank_at)
    if notation == "phonemes":
        return nemo_old.vocabs.Phonemes(stresses=stresses, **shared_kwargs)
    return nemo_old.vocabs.Chars(**shared_kwargs)
class PhonemeDurationDataset(Dataset):
    """Per-clip blank/phoneme durations extracted with a CTC aligner.

    Call :meth:`load` before indexing. It loads the pretrained
    ``asr_talknet_aligner`` model, writes a NeMo manifest for all
    ``filelists`` and reads every clip. ``self[i]`` then returns
    ``(audio_path, {"blanks": LongTensor, "tokens": LongTensor})``: the number
    of aligner output frames assigned to each blank slot and to each token.
    """

    def __init__(self, source, filelists):
        self.source = source
        self.audio_datasets = [TacotronDataset(self.source, x) for x in filelists]
        self.aligner = None
        self.parser = None
        self.BLANK_TOKEN = None
        self.LABELS = None
        self.dataset = None
        self.samples = None

    def load(self):
        """Load the aligner and phoneme parser and pre-read every clip."""
        self.aligner = EncDecCTCModel.from_pretrained(model_name="asr_talknet_aligner")
        self.parser = make_vocab(
            notation="phonemes",
            punct=True,
            spaces=True,
            stresses=False,
            add_blank_at="last",
        )
        self.BLANK_TOKEN = self.aligner.decoder.num_classes_with_blank - 1
        self.LABELS = self.aligner.decoder.vocabulary
        # delete=False + explicit close: on Windows a NamedTemporaryFile cannot
        # be reopened by name while it is still open, so NeMo could not read it.
        manifest_fp = tempfile.NamedTemporaryFile(
            mode="w", encoding="utf8", delete=False
        )
        try:
            with manifest_fp:
                self.write_manifest(manifest_fp)
            self.dataset = _AudioTextDataset(
                manifest_filepath=manifest_fp.name,
                sample_rate=22050,
                parser=self.parser,
            )
            loader = torch.utils.data.DataLoader(
                dataset=self.dataset,
                batch_size=1,
                collate_fn=self.dataset.collate_fn,
                shuffle=False,
            )
            self.samples = list(loader)
        finally:
            # Remove the manifest only after the loader has been fully consumed.
            os.remove(manifest_fp.name)

    def unload(self):
        """Release the aligner, parser and pre-read clips; safe to call twice."""
        self.aligner = None
        self.parser = None
        self.dataset = None
        self.samples = None

    def write_manifest(self, manifest_fp):
        """Write one JSON line (audio_filepath, duration, text) per clip."""
        for audio_dataset in self.audio_datasets:
            for sample in audio_dataset:
                entry = {
                    "audio_filepath": os.path.normpath(sample["audio_path"]),
                    "duration": sample["duration"],
                    "text": sample["transcript"],
                }
                manifest_fp.write(json.dumps(entry) + "\n")
        manifest_fp.flush()

    def __getitem__(self, i):
        # Assumes NeMo's collate layout (signal, signal_len, tokens, ...) with a
        # batch of one — TODO confirm against the installed NeMo version.
        sample = self.samples[i]
        device = self.aligner.device
        # Inference only: skip autograd bookkeeping for the forward pass.
        with torch.no_grad():
            log_probs, _, _ = self.aligner(
                input_signal=sample[0].to(device),
                input_signal_length=sample[1].to(device),
            )
        log_probs = log_probs.detach().cpu().numpy()[0]
        seq_ids = sample[2][0].cpu().numpy()
        target_tokens = self.preprocess_tokens(seq_ids)
        f, p = self.forward_extractor(target_tokens, log_probs, self.BLANK_TOKEN)
        durs = self.backward_extractor(f, p)
        dur_key = os.path.normpath(
            self.dataset.manifest_processor.collection[i].audio_file
        )
        dur_data = {
            "blanks": torch.tensor(durs[::2], dtype=torch.long),
            "tokens": torch.tensor(durs[1::2], dtype=torch.long),
        }
        return dur_key, dur_data

    def __len__(self):
        return sum(len(x) for x in self.audio_datasets)

    def preprocess_tokens(self, tokens):
        """Interleave blanks: [t0, t1] -> [blank, t0, blank, t1, blank]."""
        new_tokens = [self.BLANK_TOKEN]
        for c in tokens:
            new_tokens.extend([c, self.BLANK_TOKEN])
        return new_tokens

    def forward_extractor(self, tokens, log_probs, blank):
        """Max-sum (Viterbi) forward pass over the CTC state lattice.

        :param tokens: Blank-interleaved token ids (see ``preprocess_tokens``).
        :param log_probs: (frames, vocab) array of aligner log-probabilities.
        :param blank: Id of the blank token.
        :return: ``(f, p)``. ``f[s, t]`` is the max sum of log-probs covering the
            first ``t`` frames with the first ``s`` states, ending in
            ``tokens[s - 1]``. ``p[s, t]`` is how many states back (0, 1 or 2)
            the best predecessor of ``(s, t)`` lies.
        """
        n, m = len(tokens), log_probs.shape[0]
        f = numpy.full((n + 1, m + 1), -(10**9), dtype=float)
        p = numpy.empty((n + 1, m + 1), dtype=int)
        f[0, 0] = 0.0  # Start: nothing emitted yet.
        for s in range(1, n + 1):
            c = tokens[s - 1]
            # (s + 1) // 2 is the earliest frame at which state s is reachable.
            for t in range((s + 1) // 2, m + 1):
                f[s, t] = log_probs[t - 1, c]
                # `s > 2` guard: for s == 2, tokens[s - 3] would wrap to tokens[-1].
                if s == 1 or c == blank or (s > 2 and c == tokens[s - 3]):
                    # Stay or advance one state; skipping the blank in between
                    # is not allowed for blanks or repeated tokens.
                    options = f[s : (s - 2 if s > 1 else None) : -1, t - 1]
                else:
                    # Stay, advance one state, or skip the intervening blank.
                    options = f[s : (s - 3 if s > 2 else None) : -1, t - 1]
                f[s, t] += numpy.max(options)
                p[s, t] = numpy.argmax(options)
        return f, p

    def backward_extractor(self, f, p):
        """Backtrack through ``p`` to count frames per state.

        :return: int array of length ``len(tokens)``. Even indices are blank
            durations and odd indices token durations (each >= 1). The total
            equals the number of frames.
        """
        n, m = f.shape
        n -= 1
        m -= 1
        durs = numpy.zeros(n, dtype=int)
        # The best path may end on the final blank or on the last real token.
        if f[-1, -1] >= f[-2, -1]:
            s, t = n, m
        else:
            s, t = n - 1, m
        while s > 0:
            durs[s - 1] += 1
            s -= p[s, t]
            t -= 1
        assert durs.shape[0] == n
        assert numpy.sum(durs) == m
        assert numpy.all(durs[1::2] > 0)
        return durs
def merge_audio_channels(sample):
    """Return a copy of ``sample`` whose ``audio`` is averaged over dimension 0.

    Dimension 0 is presumably the channel axis (as returned by
    ``torchaudio.load``), so the result is a mono waveform. The input dict is
    not modified.
    """
    mono = sample["audio"].mean(0)
    return dict(sample, audio=mono)
class PhonemeDurationAug:
    """Sample-dict transform attaching interleaved blank/token durations.

    ``phonemeKv`` maps ``audio_path`` to ``{"blanks": Tensor, "tokens": Tensor}``,
    where ``blanks`` must be exactly one element longer than ``tokens``. The
    output ``durations`` has shape (1, 2 * len(tokens) + 1) and the order
    blank0, tok0, blank1, tok1, ..., blankN.
    """

    def __init__(self, phonemeKv):
        self.durations = phonemeKv

    def __call__(self, sample):
        entry = self.durations[sample["audio_path"]]
        blanks = entry["blanks"]
        tokens = entry["tokens"]
        # Pad tokens with one zero of the *same dtype* so both rows have equal
        # length without relying on implicit dtype promotion in torch.cat.
        padded_tokens = torch.cat((tokens, tokens.new_zeros(1)))
        # Stacking along dim=1 then flattening interleaves the two rows; the
        # trailing pad element is dropped.
        durations = (
            torch.stack((blanks, padded_tokens), dim=1).reshape(-1)[:-1].reshape(1, -1)
        )
        sample = dict(sample)
        sample["durations"] = durations
        return sample
class PitchDataset(Dataset):
    """Concatenation of filelists yielding ``(audio_path, f0)`` per clip.

    Audio is channel-merged and passed to ``crepe_f0`` with
    ``pitch_hop_length`` samples per pitch frame.
    """

    def __init__(self, source, filelists, pitch_hop_length=256):
        self.hop_length = pitch_hop_length
        self.source = source
        self.filelists = filelists
        self.audio_datasets = [
            TacotronDataset(self.source, x, transform=merge_audio_channels)
            for x in filelists
        ]
        self.sample_counts = [len(x) for x in self.audio_datasets]

    def __len__(self):
        return sum(self.sample_counts)

    def _locate(self, index):
        """Map a global index (negatives count from the end of the whole
        concatenation) to ``(dataset_number, local_index)``.

        :raises IndexError: if ``index`` is out of range.
        """
        total = len(self)
        if index < 0:
            index += total
        if not 0 <= index < total:
            raise IndexError(f"index out of range for dataset of length {total}")
        dataset_idx = 0
        while index >= self.sample_counts[dataset_idx]:
            index -= self.sample_counts[dataset_idx]
            dataset_idx += 1
        return dataset_idx, index

    def __getitem__(self, index):
        print("getting index", index, "of", len(self))
        dataset_idx, local_index = self._locate(index)
        sample = self.audio_datasets[dataset_idx][local_index]
        audio_path = sample["audio_path"]
        pitch = crepe_f0(
            (self.hop_length, audio_path, sample["audio"], sample["sample_rate"])
        )
        return audio_path, pitch
class PitchAug:
    """Sample-dict transform attaching a precomputed pitch contour.

    The contour is looked up in ``pitchKv`` by the sample's ``audio_path``. A
    missing path raises ``KeyError``. A new dict is returned; the input is left
    unchanged.
    """

    def __init__(self, pitchKv) -> None:
        # audio_path -> pitch contour
        self.pitches = pitchKv

    def __call__(self, sample):
        contour = self.pitches[sample["audio_path"]]
        return dict(sample, pitch=contour)
def crepe_f0(args):
    hop_length, audio_path, audio, sample_rate = args
    audio = audio.reshape(1, *audio.shape)
    print("processing", audio_path, sample_rate)
    # hop_length = self.hop_length
    # hop_length, audio_path, audio, sample_rate = args
    frequency, confidence = torchcrepe.predict(
        audio,
        sample_rate,
        hop_length,
        return_periodicity=True,
        fmin=100,
        fmax=600,
        device="cuda",
    )
    time = numpy.arange(confidence.shape[-1]) / 100.0
    audio = audio.cpu()[0].numpy() * 2**15
    audio_x = numpy.arange(0, len(audio)) / sample_rate
    frequency = frequency[0].cpu().numpy()
    confidence = confidence[0].cpu().numpy()
    x = numpy.arange(0, len(audio), hop_length) / sample_rate
    freq_interp = numpy.interp(x, time, frequency)
    conf_interp = numpy.interp(x, time, confidence)
    audio_interp = numpy.interp(x, audio_x, numpy.absolute(audio)) / 2**15
    weights = [0.5, 0.25, 0.25]
    audio_smooth = numpy.convolve(audio_interp, numpy.array(weights)[::-1], "same")
    conf_threshold = 0.25
    audio_threshold = 0.0005
    for i in range(len(freq_interp)):
        if conf_interp[i] < conf_threshold:
            freq_interp[i] = 0.0
        if audio_smooth[i] < audio_threshold:
            freq_interp[i] = 0.0
    # Hack to make f0 and mel lengths equal
    if len(audio) % hop_length == 0:
        freq_interp = numpy.pad(freq_interp, pad_width=[0, 1])
    return torch.from_numpy(freq_interp.astype(numpy.float32))
class TalkNetDataLoader(LightningDataModule):
    """Lightning data module serving TalkNet training/validation samples.

    Each ``TacotronDataset`` item is transformed in this order: channel merge,
    pitch lookup (``pitchKv``), duration lookup (``phonemeKv``), then
    resampling to 22050 Hz.
    """

    def __init__(
        self,
        source,
        pitchKv,
        phonemeKv,
        batch_size=1,
        train_filelist="train_filelist.txt",
        val_filelist="val_filelist.txt",
        *args,
        **kwargs,
    ) -> None:
        """
        :param source: Dataset root directory.
        :param pitchKv: audio_path -> pitch contour tensor.
        :param phonemeKv: audio_path -> {"blanks", "tokens"} duration tensors.
        :param batch_size: DataLoader batch size. Samples have variable length and
            use the default collate function, so values > 1 presumably need
            equal-length clips or a custom collate_fn — TODO confirm.
        :param train_filelist: Training filelist, relative to ``source``.
        :param val_filelist: Validation filelist, relative to ``source``.
        """
        super().__init__(*args, **kwargs)
        self.source = source
        self.train_filelist = train_filelist
        self.val_filelist = val_filelist
        self.batch_size = batch_size
        # Pitch statistics over frames above 20 Hz; unvoiced frames are stored as 0.
        all_pitches = torch.cat([x[x > 20] for x in pitchKv.values()])
        self.pitch_mean = all_pitches.mean().item()
        self.pitch_std = all_pitches.std().item()
        self.transform = Compose(
            [
                merge_audio_channels,
                PitchAug(pitchKv),
                PhonemeDurationAug(phonemeKv),
                NormalizeAudio(resample_rate=22050),
            ]
        )

    def _make_loader(self, filelist):
        """Build a DataLoader over ``filelist`` with the shared transform."""
        dataset = TacotronDataset(self.source, filelist, transform=self.transform)
        return DataLoader(dataset, batch_size=self.batch_size)

    def train_dataloader(self):
        return self._make_loader(self.train_filelist)

    def val_dataloader(self):
        return self._make_loader(self.val_filelist)

    def test_dataloader(self):
        return self.val_dataloader()
class Cached:
    """Disk cache for a dataset of ``(key, value)`` items, materialized as a dict.

    Item ``i`` is stored at ``<cache_path>/<i>.pt`` via ``torch.save``. If the
    dataset defines ``load``, it is called lazily before the first uncached item
    is computed. If it defines ``unload``, that is called once ``dict()`` is done
    with a loaded dataset.

    NOTE: entries are keyed by index only. If the underlying filelists change,
    clear ``cache_path``, or stale entries will be returned.
    """

    def __init__(self, dataset, cache_path):
        self.dataset = dataset
        self.cache_path = cache_path
        # Datasets without a load() hook are always ready.
        self.loaded = not hasattr(self.dataset, "load")
        os.makedirs(self.cache_path, exist_ok=True)

    def _ensure_loaded(self):
        """Call the dataset's load() hook once before computing uncached items."""
        if not self.loaded:
            self.dataset.load()
            self.loaded = True

    def _release(self):
        """Call the dataset's unload() hook and mark it as needing a reload."""
        if self.loaded and hasattr(self.dataset, "unload"):
            self.dataset.unload()
            # Force a fresh load() the next time an uncached item is needed.
            self.loaded = not hasattr(self.dataset, "load")

    def dict(self):
        """Return ``{key: value}`` for every item, computing and caching misses."""
        result = {}
        try:
            for i in range(len(self.dataset)):
                cache_path = os.path.join(self.cache_path, f"{i}.pt")
                if os.path.exists(cache_path):
                    key, value = torch.load(cache_path)
                else:
                    self._ensure_loaded()
                    key, value = self.dataset[i]
                    # Write-then-rename so an interrupted run never leaves a
                    # truncated entry that later loads would trip over.
                    tmp_path = cache_path + ".tmp"
                    torch.save((key, value), tmp_path)
                    os.replace(tmp_path, cache_path)
                result[key] = value
        finally:
            self._release()
        return result
def load_talknet_data(
    source,
    train_filelist="train_filelist.txt",
    val_filelist="val_filelist.txt",
    cache_root=".",
):
    """Compute (or load cached) pitch and duration data and build the data module.

    :param source: Dataset root directory containing the filelists and audio.
    :param train_filelist: Training filelist, relative to ``source``.
    :param val_filelist: Validation filelist, relative to ``source``.
    :param cache_root: Directory holding ``pitch-data-cache`` and
        ``phoneme-data-cache``. Caches are keyed by item index, so use a separate
        ``cache_root`` per ``source``.
    :return: A ``TalkNetDataLoader`` over ``source``.
    """
    filelists = [train_filelist, val_filelist]
    pitchKv = Cached(
        PitchDataset(source, filelists),
        os.path.join(cache_root, "pitch-data-cache"),
    ).dict()
    phonemeKv = Cached(
        PhonemeDurationDataset(source, filelists),
        os.path.join(cache_root, "phoneme-data-cache"),
    ).dict()
    return TalkNetDataLoader(
        source,
        pitchKv,
        phonemeKv,
        train_filelist=train_filelist,
        val_filelist=val_filelist,
    )
                         by Guest
                         by Guest
                         by Guest
                         by Guest
                         by Guest