update search function to match monai 1.2
- configs/metadata.json +2 -1
- scripts/prepare_datalist.py +5 -5
- scripts/search.py +2 -5
configs/metadata.json
CHANGED
@@ -1,7 +1,8 @@
 {
     "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20220324.json",
-    "version": "0.4.1",
+    "version": "0.4.2",
     "changelog": {
+        "0.4.2": "update search function to match monai 1.2",
         "0.4.1": "fix the wrong GPU index issue of multi-node",
         "0.4.0": "remove error dollar symbol in readme",
         "0.3.9": "add cpu ram requirement in readme",
scripts/prepare_datalist.py
CHANGED
@@ -11,11 +11,10 @@ def produce_sample_dict(line: str):
     return {"label": line, "image": line.replace("labelsTr", "imagesTr")}


-def produce_datalist(dataset_dir: str):
+def produce_datalist(dataset_dir: str, train_size: int = 196):
     """
     This function is used to split the dataset.
-    It will produce 196 samples for training, and the other samples are divided
-    into val and test sets.
+    It will produce "train_size" number of samples for training.
     """

     samples = sorted(glob.glob(os.path.join(dataset_dir, "labelsTr", "*"), recursive=True))
@@ -23,7 +22,7 @@ def produce_datalist(dataset_dir: str):
     datalist = []
     for line in samples:
         datalist.append(produce_sample_dict(line))
-    train_list, other_list = train_test_split(datalist, train_size=196)
+    train_list, other_list = train_test_split(datalist, train_size=train_size)
     val_list, test_list = train_test_split(other_list, train_size=0.66)

     return {"training": train_list, "validation": val_list, "testing": test_list}
@@ -37,7 +36,7 @@ def main(args):
     output_json = args.output
     # produce deterministic data splits
     monai.utils.set_determinism(seed=123)
-    datalist = produce_datalist(dataset_dir=data_file_base_dir)
+    datalist = produce_datalist(dataset_dir=data_file_base_dir, train_size=args.train_size)
     with open(output_json, "w") as f:
         json.dump(datalist, f, ensure_ascii=True, indent=4)

@@ -53,6 +52,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--output", type=str, default="dataset_0.json", help="relative path of output datalist json file."
     )
+    parser.add_argument("--train_size", type=int, default=196, help="number of training samples.")
     args = parser.parse_args()

     main(args)
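For context, the parameterized split can be sketched in isolation. The snippet below is illustrative only: it assumes the script's train_test_split comes from scikit-learn (its import sits outside the hunks shown) and uses made-up file names in place of the real dataset paths.

from sklearn.model_selection import train_test_split

# Hypothetical stand-in for the label/image pairs built by produce_sample_dict.
datalist = [
    {"label": f"labelsTr/case_{i}.nii.gz", "image": f"imagesTr/case_{i}.nii.gz"}
    for i in range(280)
]

# train_size is now an absolute sample count (default 196), overridable via e.g.
#   python scripts/prepare_datalist.py --train_size 150 ...
train_list, other_list = train_test_split(datalist, train_size=196)
# The remaining samples are split roughly 2:1 into validation and test sets.
val_list, test_list = train_test_split(other_list, train_size=0.66)
print(len(train_list), len(val_list), len(test_list))  # 196 training cases, rest split ~2:1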
scripts/search.py
CHANGED
@@ -28,7 +28,7 @@ from monai import transforms
 from monai.bundle import ConfigParser
 from monai.data import ThreadDataLoader, partition_dataset
 from monai.inferers import sliding_window_inference
-from monai.metrics import compute_meandice
+from monai.metrics import compute_dice
 from monai.utils import set_determinism
 from torch.nn.parallel import DistributedDataParallel
 from torch.utils.tensorboard import SummaryWriter
@@ -100,14 +100,12 @@ def run(config_file: Union[str, Sequence[str]]):
         train_files_w = partition_dataset(
             data=train_files_w, shuffle=True, num_partitions=world_size, even_divisible=True
         )[dist.get_rank()]
-    print("train_files_w:", len(train_files_w))

     train_files_a = train_files[len(train_files) // 2 :]
     if torch.cuda.device_count() > 1:
         train_files_a = partition_dataset(
             data=train_files_a, shuffle=True, num_partitions=world_size, even_divisible=True
         )[dist.get_rank()]
-    print("train_files_a:", len(train_files_a))

     # validation data
     files = []
@@ -125,7 +123,6 @@ def run(config_file: Union[str, Sequence[str]]):
     val_files = partition_dataset(data=val_files, shuffle=False, num_partitions=world_size, even_divisible=False)[
         dist.get_rank()
     ]
-    print("val_files:", len(val_files))

     # network architecture
     if torch.cuda.device_count() > 1:
@@ -421,7 +418,7 @@ def run(config_file: Union[str, Sequence[str]]):
                 val_labels = post_label(val_labels[0, ...])
                 val_labels = val_labels[None, ...]

-                value = compute_meandice(y_pred=val_outputs, y=val_labels, include_background=False)
+                value = compute_dice(y_pred=val_outputs, y=val_labels, include_background=False)

                 print(_index + 1, "/", len(val_loader), value)

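Aside from dropping the debug print calls, the change in search.py is the metric rename for MONAI 1.2, where compute_dice replaces the older compute_meandice name. Below is a minimal, self-contained sketch of the renamed call; the toy one-hot tensors stand in for the script's post-processed val_outputs and val_labels.

import torch
from monai.metrics import compute_dice  # available in MONAI 1.2; replaces the deprecated compute_meandice

# Toy one-hot tensors in (batch, channel, H, W, D) layout, standing in for the
# post-processed predictions and labels used by the search script.
y_pred = torch.zeros(1, 2, 8, 8, 8)
y = torch.zeros(1, 2, 8, 8, 8)
y_pred[:, 1, :4] = 1             # predicted foreground: 4 of 8 slices
y[:, 1, :5] = 1                  # ground-truth foreground: 5 of 8 slices
y_pred[:, 0] = 1 - y_pred[:, 1]  # background channel
y[:, 0] = 1 - y[:, 1]

# Returns one Dice score per (sample, class); background excluded as in the bundle.
value = compute_dice(y_pred=y_pred, y=y, include_background=False)
print(value)  # tensor([[0.8889]]) == 2 * 256 / (256 + 320)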