1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85
| import os import json import random import matplotlib.pyplot as plt
# Assumes the dataset folder contains one sub-folder per class, e.g. three
# classes, which yields a class_indices.json such as:
#     { "0": "AD", "1": "CN", "2": "MCI" }


def read_split_data(root: str, val_rate: float = 0.2, plot_image: bool = True):
    """Split an image-classification dataset folder into train/val lists.

    Every immediate sub-directory of *root* is treated as one class. Class
    names are sorted and mapped to integer labels ``0..K-1``. The inverse
    mapping (label -> class name) is written as JSON to
    ``class_indices.json`` in the current working directory.

    Within each class, only files whose extension is one of ``.jpg``,
    ``.JPG``, ``.png`` or ``.PNG`` are used. ``int(len(images) * val_rate)``
    of them are drawn at random for validation, and the rest go to training.

    Args:
        root: Dataset directory containing one sub-folder per class.
        val_rate: Fraction of each class's images assigned to validation.
        plot_image: When True (the default, matching the previously
            hard-coded behaviour), draw a bar chart of per-class image counts
            and call ``plt.show()``.

    Returns:
        A 4-tuple ``(train_images_path, train_images_label,
        val_images_path, val_images_label)`` of parallel lists holding file
        paths and integer labels.

    Raises:
        AssertionError: If *root* does not exist, or if the training or
            validation split ends up empty.
    """
    # NOTE: this reseeds the *global* ``random`` module so the split is
    # reproducible. It also affects any later use of ``random`` by the caller.
    random.seed(0)
    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)

    # One class per sub-directory, sorted so labels are stable across runs.
    data_class = sorted(
        cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))
    )
    class_indices = {name: index for index, name in enumerate(data_class)}
    # Persist the inverse mapping (label -> class name) for later inference.
    json_str = json.dumps(
        {index: name for name, index in class_indices.items()}, indent=4
    )
    with open('class_indices.json', 'w', encoding='utf-8') as json_file:
        json_file.write(json_str)

    train_images_path = []   # file paths of training images
    train_images_label = []  # integer labels of training images
    val_images_path = []     # file paths of validation images
    val_images_label = []    # integer labels of validation images
    every_class_num = []     # number of supported images found per class
    supported = [".jpg", ".JPG", ".png", ".PNG"]

    for cla in data_class:
        cla_path = os.path.join(root, cla)
        images = sorted(
            os.path.join(cla_path, i)
            for i in os.listdir(cla_path)
            if os.path.splitext(i)[-1] in supported
        )
        image_class = class_indices[cla]
        every_class_num.append(len(images))
        # A set makes the membership test below O(1). Previously a list scan
        # made the split O(n^2) per class. The RNG draw itself is unchanged.
        val_path = set(random.sample(images, k=int(len(images) * val_rate)))
        for img_path in images:
            if img_path in val_path:
                val_images_path.append(img_path)
                val_images_label.append(image_class)
            else:
                train_images_path.append(img_path)
                train_images_label.append(image_class)

    print("{} images were found in the dataset.".format(sum(every_class_num)))
    print("{} images for training.".format(len(train_images_path)))
    print("{} images for validation.".format(len(val_images_path)))
    assert len(train_images_path) > 0, "number of training images must greater than 0."
    assert len(val_images_path) > 0, "number of validation images must greater than 0."

    if plot_image:
        _plot_class_distribution(data_class, every_class_num)

    return train_images_path, train_images_label, val_images_path, val_images_label


def _plot_class_distribution(class_names, counts):
    """Draw a bar chart of image counts per class and call ``plt.show()``."""
    positions = range(len(class_names))
    plt.bar(positions, counts, align='center')
    plt.xticks(positions, class_names)
    # Annotate each bar with its count, slightly above the bar top.
    for i, v in enumerate(counts):
        plt.text(x=i, y=v + 5, s=str(v), ha='center')
    plt.xlabel('image class')
    plt.ylabel('number of images')
    plt.title('data class distribution')
    plt.show()
# Script entry: split the dataset at this location.
# Raw string so the Windows backslash is not parsed as an escape sequence.
# "\d" is an invalid escape that warns today and is slated to become an error.
data_path = r"D:\data_set"
train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(data_path)
|