import os

import pandas as pd
import torch
import torchaudio
from torch.utils.data import Dataset as TorchDataset
class ChainsawDataset(TorchDataset):
    """Map-style dataset of audio clips indexed by a ``labels.csv`` file.

    The dataset directory must contain a ``labels.csv`` whose rows hold
    exactly two values, in order: the audio file name (relative to the
    dataset directory) and its label. Audio is decoded lazily, one clip per
    ``__getitem__`` call.

    Attributes:
        path: Dataset root directory, as passed to the constructor.
        ds: ``pandas.DataFrame`` loaded from ``labels.csv``.
    """

    def __init__(self, path: str = "../datasets/freesound/") -> None:
        """Load the label index from ``<path>/labels.csv``.

        Args:
            path: Directory containing ``labels.csv`` and the audio files.
                A trailing separator is optional.
        """
        super().__init__()
        self.path = path
        # os.path.join keeps the default location unchanged while also
        # accepting roots given without a trailing slash.
        self.ds = pd.read_csv(os.path.join(self.path, "labels.csv"))

    def __getitem__(self, index):
        """Load the clip at row ``index`` of the label index.

        Args:
            index: Positional row index into ``labels.csv``.

        Returns:
            A dict of the form::

                {
                    'audio': {
                        'path': <file name as listed in the CSV>,
                        'array': <waveform tensor>,
                        'sampling_rate': <0-d tensor>,
                    },
                    'label': <tensor built from the CSV label>,
                }
        """
        # Each row is unpacked positionally, so the CSV must have exactly
        # two columns: file name first, label second.
        file, label = self.ds.iloc[index]
        x, sr = torchaudio.load(os.path.join(self.path, file))
        # torchaudio.load returns a channels-first tensor by default; drop
        # only the channel axis (mono -> 1-D). A bare squeeze() would also
        # collapse a one-sample clip into a 0-d tensor.
        x = x.squeeze(0)
        return {
            'audio': {
                'path': file,
                'array': x,
                'sampling_rate': torch.tensor(sr),
            },
            'label': torch.tensor(label),
        }

    def __len__(self) -> int:
        """Return the number of rows in the label index."""
        return len(self.ds)