import json
import os
import tempfile
import numpy
import pandas
from pytorch_lightning import LightningDataModule
import torch
import torchaudio
from torchvision.transforms import Compose
from torch.utils.data import Dataset, DataLoader
from nemo.collections.asr.data.audio_to_text import _AudioTextDataset
from nemo.collections.asr.models import EncDecCTCModel
import nemo_old
import nemo_old.vocabs
import torchcrepe
nemo_old.disable_strict()
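# Data pipeline for TalkNet training on an audio|transcript corpus:
#   * TacotronDataset reads filelist entries and loads the referenced audio.
#   * PhonemeDurationDataset force-aligns transcripts with NeMo's
#     asr_talknet_aligner to get per-phoneme durations.
#   * PitchDataset extracts f0 contours with torchcrepe.
#   * Cached memoizes both extraction passes to disk.
#   * TalkNetDataLoader composes everything into a LightningDataModule.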
class TacotronDataset(Dataset):
def __init__(self, source, filelist_fn, transform=None):
"""
:param source: Path to the dataset.
:param filelist_fn: File path within source containing audio|transcript lines.
:transform: transformation to apply to the audio, (audio, rate) -> (audio, duration)
"""
self.source = source
filelist_path = os.path.join(source, filelist_fn)
self.samples = pandas.read_csv(filelist_path, header=None, sep="|")
self.transform = transform
def __len__(self):
return len(self.samples)
def __getitem__(self, index):
source_fn = self.samples.iloc[index, 0]
transcript = self.samples.iloc[index, 1]
source_path = os.path.join(self.source, source_fn)
audio, rate = torchaudio.load(source_path)
duration = audio.shape[-1] / rate
sample = {
"audio": audio,
"sample_rate": rate,
"duration": duration,
"audio_path": os.path.normpath(source_path),
"transcript": transcript,
}
if self.transform:
sample = self.transform(sample)
return sample
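# Example usage (a sketch; the dataset root and filelist name are the defaults
# used by load_talknet_data at the bottom of this file):
#   ds = TacotronDataset("./sunset-singing", "train_filelist.txt")
#   sample = ds[0]
#   sample["audio"].shape, sample["sample_rate"], sample["transcript"]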
class NormalizeAudio:
def __init__(
self,
resample_rate=22050,
lowpass_filter_width=6,
rolloff_frequency=0.99,
resampling_method="sinc_interpolation",
kaiser_window_beta=None,
):
self.resample_rate = resample_rate
self.lowpass_filter_width = lowpass_filter_width
self.rolloff_frequency = rolloff_frequency
self.resampling_method = resampling_method
self.kaiser_window_beta = kaiser_window_beta
    def __call__(self, sample):
        target_rate = self.resample_rate or sample["sample_rate"]
        audio = torchaudio.functional.resample(
            sample["audio"],
            sample["sample_rate"],
            target_rate,
            self.lowpass_filter_width,
            self.rolloff_frequency,
            self.resampling_method,
            self.kaiser_window_beta,
        )
        sample = dict(sample)
        sample["audio"] = audio
        sample["sample_rate"] = target_rate
        return sample
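# The 22050 Hz default matches the sample_rate PhonemeDurationDataset passes
# to _AudioTextDataset below, keeping resampled audio and aligned durations
# consistent.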
class AlignAudio:
def __init__(self):
super().__init__()
self.aligner = EncDecCTCModel.from_pretrained(model_name="asr_talknet_aligner")
self.BLANK_TOKEN = self.aligner.decoder.num_classes_with_blank - 1
        self.LABELS = self.aligner.decoder.vocabulary
self.parser = make_vocab(
notation="phonemes",
punct=True,
spaces=True,
stresses=False,
add_blank_at="last",
)
    def __call__(self, sample):
        audio = sample["audio"].reshape(1, -1).to(self.aligner.device)  # (batch, time)
        length = torch.tensor([audio.shape[-1]], dtype=torch.long, device=audio.device)
        log_probs, _, greedy_predictions = self.aligner(
            input_signal=audio, input_signal_length=length
        )
        return log_probs
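# Note: AlignAudio duplicates the aligner/vocab setup that
# PhonemeDurationDataset.load performs below; the TalkNetDataLoader transform
# pipeline uses PhonemeDurationAug rather than this transform.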
def make_vocab(
notation="chars",
punct=True,
spaces=False,
stresses=False,
add_blank_at="last_but_one",
):
"""Stolen from Nvidia NeMo. Constructs vocabulary from given parameters.
Args:
notation (str): Either 'chars' or 'phonemes' as general notation.
punct (bool): True if reserve grapheme for basic punctuation.
spaces (bool): True if prepend spaces to every punctuation symbol.
stresses (bool): True if use phonemes codes with stresses (0-2).
add_blank_at: add blank to labels in the specified order ("last" or "last_but_one"),
if None then no blank in labels.
Returns:
(vocabs.Base) Vocabulary
"""
if notation == "chars":
vocab = nemo_old.vocabs.Chars(
punct=punct, spaces=spaces, add_blank_at=add_blank_at
)
elif notation == "phonemes":
vocab = nemo_old.vocabs.Phonemes(
punct=punct,
stresses=stresses,
spaces=spaces,
add_blank_at=add_blank_at,
)
else:
raise ValueError("Unsupported vocab type.")
return vocab
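# Example: the phoneme vocabulary used with the aligner in this file (mirrors
# the arguments passed in AlignAudio and PhonemeDurationDataset.load):
#   parser = make_vocab(notation="phonemes", punct=True, spaces=True,
#                       stresses=False, add_blank_at="last")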
class PhonemeDurationDataset(Dataset):
def __init__(self, source, filelists):
self.source = source
self.audio_datasets = [TacotronDataset(self.source, x) for x in filelists]
self.aligner = None
self.parser = None
self.BLANK_TOKEN = None
self.LABELS = None
self.dataset = None
self.samples = None
def load(self):
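        # Deferred setup: Cached only calls load() when at least one sample is
        # missing from the on-disk cache, so the aligner is never instantiated
        # for fully cached runs.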
self.aligner = EncDecCTCModel.from_pretrained(model_name="asr_talknet_aligner")
self.parser = make_vocab(
notation="phonemes",
punct=True,
spaces=True,
stresses=False,
add_blank_at="last",
)
self.BLANK_TOKEN = self.aligner.decoder.num_classes_with_blank - 1
self.LABELS = self.aligner.decoder.vocabulary
with tempfile.NamedTemporaryFile(mode="w", encoding="utf8") as manifest_fp:
self.write_manifest(manifest_fp)
self.dataset = _AudioTextDataset(
manifest_filepath=manifest_fp.name,
sample_rate=22050,
parser=self.parser,
)
loader = torch.utils.data.DataLoader(
dataset=self.dataset,
batch_size=1,
collate_fn=self.dataset.collate_fn,
shuffle=False,
)
self.samples = list(loader)
    def unload(self):
        # Drop the aligner, parser, and cached batches so their (GPU) memory
        # can be reclaimed after extraction.
        self.aligner = None
        self.parser = None
        self.dataset = None
        self.samples = None
def write_manifest(self, manifest_fp):
for audio_dataset in self.audio_datasets:
for sample in audio_dataset:
transcript = sample["transcript"]
audio_path = sample["audio_path"]
duration = sample["audio"].shape[-1] / sample["sample_rate"]
manifest_fp.write(
json.dumps(
{
"audio_filepath": os.path.normpath(audio_path),
"duration": duration,
"text": transcript,
}
)
)
manifest_fp.write("\n")
manifest_fp.flush()
def __getitem__(self, i):
sample = self.samples[i]
log_probs, _, _ = self.aligner(
input_signal=sample[0].to(self.aligner.device),
input_signal_length=sample[1].to(self.aligner.device),
)
log_probs = log_probs.cpu().detach().numpy()[0]
seq_ids = sample[2][0].cpu().detach().numpy()
target_tokens = self.preprocess_tokens(seq_ids)
f, p = self.forward_extractor(target_tokens, log_probs, self.BLANK_TOKEN)
durs = self.backward_extractor(f, p)
dur_key = os.path.normpath(
self.dataset.manifest_processor.collection[i].audio_file
)
        dur_data = {
            "blanks": torch.tensor(durs[::2], dtype=torch.long),
            "tokens": torch.tensor(durs[1::2], dtype=torch.long),
        }
return dur_key, dur_data
def __len__(self):
return sum([len(x) for x in self.audio_datasets])
def preprocess_tokens(self, tokens):
new_tokens = [self.BLANK_TOKEN]
for c in tokens:
new_tokens.extend([c, self.BLANK_TOKEN])
tokens = new_tokens
return tokens
def forward_extractor(self, tokens, log_probs, blank):
"""Computes states f and p."""
n, m = len(tokens), log_probs.shape[0]
# `f[s, t]` -- max sum of log probs for `s` first codes
# with `t` first timesteps with ending in `tokens[s]`.
f = numpy.empty((n + 1, m + 1), dtype=float)
f.fill(-(10**9))
p = numpy.empty((n + 1, m + 1), dtype=int)
f[0, 0] = 0.0 # Start
for s in range(1, n + 1):
c = tokens[s - 1]
for t in range((s + 1) // 2, m + 1):
f[s, t] = log_probs[t - 1, c]
# Option #1: prev char is equal to current one.
if s == 1 or c == blank or c == tokens[s - 3]:
options = f[s : (s - 2 if s > 1 else None) : -1, t - 1]
else: # Is not equal to current one.
options = f[s : (s - 3 if s > 2 else None) : -1, t - 1]
f[s, t] += numpy.max(options)
p[s, t] = numpy.argmax(options)
return f, p
def backward_extractor(self, f, p):
"""Computes durs from f and p."""
n, m = f.shape
n -= 1
m -= 1
durs = numpy.zeros(n, dtype=int)
if f[-1, -1] >= f[-2, -1]:
s, t = n, m
else:
s, t = n - 1, m
while s > 0:
durs[s - 1] += 1
s -= p[s, t]
t -= 1
assert durs.shape[0] == n
assert numpy.sum(durs) == m
assert numpy.all(durs[1::2] > 0)
return durs
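# forward_extractor/backward_extractor implement a Viterbi-style forced
# alignment over the blank-interleaved token sequence from preprocess_tokens:
# f accumulates the best path log-probability, p stores backpointers, and the
# backward pass converts the recovered path into per-position frame counts,
# with durs[::2] covering blanks and durs[1::2] covering phoneme tokens.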
def merge_audio_channels(sample):
sample = dict(sample)
sample["audio"] = sample["audio"].mean(0)
return sample
class PhonemeDurationAug:
def __init__(self, phonemeKv):
self.durations = phonemeKv
def __call__(self, sample):
key = sample["audio_path"]
durations = (
torch.stack(
(
self.durations[key]["blanks"],
                    torch.cat((self.durations[key]["tokens"], torch.zeros(1, dtype=torch.long))),
),
dim=1,
)
.view(-1)[:-1]
.view(1, -1)
)
sample = dict(sample)
sample["durations"] = durations
return sample
class PitchDataset(Dataset):
def __init__(self, source, filelists, pitch_hop_length=256):
self.hop_length = pitch_hop_length
self.source = source
self.filelists = filelists
self.audio_datasets = [
TacotronDataset(self.source, x, transform=merge_audio_channels)
for x in filelists
]
self.sample_counts = [len(x) for x in self.audio_datasets]
def __len__(self):
return sum(self.sample_counts)
def __getitem__(self, index):
# figure out which dataset contains the index
dataset = 0
print("getting index", index, "of", len(self))
while index >= len(self.audio_datasets[dataset]):
index -= len(self.audio_datasets[dataset])
dataset += 1
sample = self.audio_datasets[dataset][index]
audio_path = sample["audio_path"]
audio = sample["audio"]
sample_rate = sample["sample_rate"]
pitch = crepe_f0((self.hop_length, audio_path, audio, sample_rate))
return audio_path, pitch
class PitchAug:
def __init__(self, pitchKv) -> None:
self.pitches = pitchKv
def __call__(self, sample):
audio_path = sample["audio_path"]
sample = dict(sample)
sample["pitch"] = self.pitches[audio_path]
return sample
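# Both PitchAug and PhonemeDurationAug look precomputed values up by the
# normalized audio path, which is why TacotronDataset stores
# os.path.normpath(source_path) in each sample.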
def crepe_f0(args):
hop_length, audio_path, audio, sample_rate = args
audio = audio.reshape(1, *audio.shape)
print("processing", audio_path, sample_rate)
frequency, confidence = torchcrepe.predict(
audio,
sample_rate,
hop_length,
return_periodicity=True,
fmin=100,
fmax=600,
device="cuda",
)
    # torchcrepe emits one frame every hop_length samples.
    time = numpy.arange(confidence.shape[-1]) * hop_length / sample_rate
    audio = audio.cpu()[0].numpy() * 2**15
    audio_x = numpy.arange(0, len(audio)) / sample_rate
    frequency = frequency[0].cpu().numpy()
    confidence = confidence[0].cpu().numpy()
    # Resample f0, confidence, and (normalized) amplitude onto a common
    # hop_length-spaced grid.
    x = numpy.arange(0, len(audio), hop_length) / sample_rate
    freq_interp = numpy.interp(x, time, frequency)
    conf_interp = numpy.interp(x, time, confidence)
    audio_interp = numpy.interp(x, audio_x, numpy.absolute(audio)) / 2**15
    # Light 3-tap smoothing of the amplitude envelope before thresholding.
    weights = [0.5, 0.25, 0.25]
    audio_smooth = numpy.convolve(audio_interp, numpy.array(weights)[::-1], "same")
    # Zero out f0 in frames that are low-confidence or near-silent (unvoiced).
    conf_threshold = 0.25
    audio_threshold = 0.0005
for i in range(len(freq_interp)):
if conf_interp[i] < conf_threshold:
freq_interp[i] = 0.0
if audio_smooth[i] < audio_threshold:
freq_interp[i] = 0.0
# Hack to make f0 and mel lengths equal
if len(audio) % hop_length == 0:
freq_interp = numpy.pad(freq_interp, pad_width=[0, 1])
return torch.from_numpy(freq_interp.astype(numpy.float32))
class TalkNetDataLoader(LightningDataModule):
def __init__(
self,
source,
pitchKv,
phonemeKv,
batch_size=1,
train_filelist="train_filelist.txt",
val_filelist="val_filelist.txt",
*args,
**kwargs,
) -> None:
super().__init__(*args, **kwargs)
self.source = source
self.train_filelist = train_filelist
self.val_filelist = val_filelist
        self.batch_size = batch_size
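        # Corpus-level pitch statistics, computed over voiced frames only
        # (f0 > 20 Hz).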
all_pitches = torch.cat([x[x > 20] for x in pitchKv.values()])
self.pitch_mean = all_pitches.mean().item()
self.pitch_std = all_pitches.std().item()
self.transform = Compose(
[
merge_audio_channels,
PitchAug(pitchKv),
PhonemeDurationAug(phonemeKv),
NormalizeAudio(resample_rate=22050),
]
)
def train_dataloader(self):
return DataLoader(
TacotronDataset(self.source, self.train_filelist, transform=self.transform),
batch_size=self.batch_size,
)
def val_dataloader(self):
return DataLoader(
TacotronDataset(self.source, self.val_filelist, transform=self.transform),
batch_size=self.batch_size,
)
def test_dataloader(self):
return self.val_dataloader()
class Cached:
def __init__(self, dataset, cache_path):
self.dataset = dataset
self.cache_path = cache_path
self.loaded = not hasattr(self.dataset, "load")
os.makedirs(self.cache_path, exist_ok=True)
def dict(self):
result = {}
for i in range(len(self.dataset)):
cache_path = os.path.join(self.cache_path, f"{i}.pt")
if os.path.exists(cache_path):
key, value = torch.load(cache_path)
else:
if not self.loaded:
self.dataset.load()
self.loaded = True
key, value = self.dataset[i]
torch.save((key, value), cache_path)
result[key] = value
if self.loaded and hasattr(self.dataset, "unload"):
self.dataset.unload()
return result
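# Cached materializes a dataset once, persisting each (key, value) pair to
# <cache_path>/<index>.pt; dataset.load() is deferred until the first cache
# miss, and unload() is called afterwards to free resources.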
def load_talknet_data(
source, train_filelist="train_filelist.txt", val_filelist="val_filelist.txt"
):
    pitchKv = Cached(
        PitchDataset(source, [train_filelist, val_filelist]),
        "./pitch-data-cache",
    ).dict()
    phonemeKv = Cached(
        PhonemeDurationDataset(source, [train_filelist, val_filelist]),
        "./phoneme-data-cache",
    ).dict()
    return TalkNetDataLoader(
        source,
pitchKv,
phonemeKv,
train_filelist=train_filelist,
val_filelist=val_filelist,
)
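if __name__ == "__main__":
    # Minimal smoke test (a sketch: assumes ./sunset-singing with the default
    # filelists exists, and that a CUDA device is available for torchcrepe
    # and the NeMo aligner).
    datamodule = load_talknet_data("./sunset-singing")
    batch = next(iter(datamodule.train_dataloader()))
    for key, value in batch.items():
        print(key, getattr(value, "shape", value))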