tomegg3 commited on
Commit
c2e0a0d
·
verified ·
1 Parent(s): 9bed453

Upload MPTS-52 checkpoints and training files

Browse files
EncDec-ODE-Gamma/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5fd2f9060e6de1382938db662f8d987aa38fa727cf9e96bd8f17571160290149
3
+ size 49644411
EncDec-ODE-Gamma/train.yaml ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.single_stochastic_interpolant_identity.SingleStochasticInterpolantIdentity
8
+ # fractional coordinates
9
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
10
+ init_args:
11
+ interpolant:
12
+ class_path: omg.si.interpolants.PeriodicEncoderDecoderInterpolant
13
+ init_args:
14
+ switch_time: 0.6487086666110259
15
+ power: 1.0
16
+ gamma:
17
+ class_path: omg.si.gamma.LatentGammaEncoderDecoder
18
+ init_args:
19
+ a: 1.9883383838119686
20
+ switch_time: 0.6487086666110259
21
+ power: 1.0
22
+ epsilon: null
23
+ differential_equation_type: "ODE"
24
+ integrator_kwargs:
25
+ method: "euler"
26
+ velocity_annealing_factor: 12.290317841755964
27
+ correct_center_of_mass_motion: true
28
+ # lattice vectors
29
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
30
+ init_args:
31
+ interpolant: omg.si.interpolants.TrigonometricInterpolant
32
+ gamma:
33
+ class_path: omg.si.gamma.LatentGammaSqrt
34
+ init_args:
35
+ a: 0.21935645939922985
36
+ epsilon:
37
+ class_path: omg.si.epsilon.VanishingEpsilon
38
+ init_args:
39
+ c: 9.431054439782873
40
+ mu: 0.21809909486896933
41
+ sigma: 0.03292165737293197
42
+ differential_equation_type: "SDE"
43
+ integrator_kwargs:
44
+ method: "euler"
45
+ dt: 0.001218559336848557
46
+ velocity_annealing_factor: 4.302804708170181
47
+ correct_center_of_mass_motion: false
48
+ data_fields:
49
+ # if the order of the data_fields changes,
50
+ # the order of the above StochasticInterpolant inputs must also change
51
+ - "species"
52
+ - "pos"
53
+ - "cell"
54
+ integration_time_steps: 820
55
+ relative_si_costs:
56
+ species_loss: 0.0
57
+ pos_loss_b: 0.689192251322191
58
+ cell_loss_b: 0.12351464867571432
59
+ cell_loss_z: 0.18729310000209468
60
+ sampler:
61
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
62
+ init_args:
63
+ pos_distribution: null
64
+ cell_distribution:
65
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
66
+ init_args:
67
+ dataset_name: mpts_52
68
+ species_distribution:
69
+ class_path: omg.sampler.distributions.MirrorData
70
+ model:
71
+ class_path: omg.model.model.Model
72
+ init_args:
73
+ encoder:
74
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
75
+ head:
76
+ class_path: omg.model.heads.pass_through.PassThrough
77
+ time_embedder:
78
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
79
+ init_args:
80
+ dim: 256
81
+ use_min_perm_dist: False
82
+ float_32_matmul_precision: "high"
83
+ validation_mode: "match_rate"
84
+ number_cpus: 7
85
+ dataset_name: "mpts_52"
86
+ data:
87
+ train_dataset:
88
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
89
+ init_args:
90
+ dataset:
91
+ class_path: omg.datamodule.datamodule.DataModule
92
+ init_args:
93
+ lmdb_paths:
94
+ - "data/mpts_52/train.lmdb"
95
+ niggli: False
96
+ val_dataset:
97
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
98
+ init_args:
99
+ dataset:
100
+ class_path: omg.datamodule.datamodule.DataModule
101
+ init_args:
102
+ lmdb_paths:
103
+ - "data/mpts_52/val.lmdb"
104
+ niggli: False
105
+ predict_dataset:
106
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
107
+ init_args:
108
+ dataset:
109
+ class_path: omg.datamodule.datamodule.DataModule
110
+ init_args:
111
+ lmdb_paths:
112
+ - "data/mpts_52/test.lmdb"
113
+ niggli: False
114
+ batch_size: 256
115
+ num_workers: 4
116
+ pin_memory: True
117
+ persistent_workers: True
118
+ trainer:
119
+ callbacks:
120
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
121
+ init_args:
122
+ filename: "best_val_loss_total"
123
+ save_top_k: 1
124
+ monitor: "val_loss_total"
125
+ save_weights_only: true
126
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
127
+ init_args:
128
+ filename: "best_val_match_rate"
129
+ save_top_k: 1
130
+ monitor: "match_rate"
131
+ save_weights_only: true
132
+ mode: 'max'
133
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
134
+ init_args:
135
+ filename: "best_val_rmsd"
136
+ save_top_k: 1
137
+ monitor: "mean_rmsd"
138
+ save_weights_only: true
139
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
140
+ init_args:
141
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
142
+ monitor: "val_loss_total"
143
+ every_n_epochs: 100
144
+ save_weights_only: false
145
+ gradient_clip_val: 0.5
146
+ num_sanity_val_steps: 0
147
+ precision: "32-true"
148
+ max_epochs: 2000
149
+ enable_progress_bar: true
150
+ limit_val_batches: 0.5
151
+ check_val_every_n_epoch: 100
152
+ optimizer:
153
+ class_path: torch.optim.Adam
154
+ init_args:
155
+ lr: 0.00047748599389170053
EncDec-SDE-Gamma/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ecffeaba4cee77a76640b78e1954ee032f58e7e7b2662ed850dfa8ff5123b1fd
3
+ size 49644475
EncDec-SDE-Gamma/train.yaml ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.single_stochastic_interpolant_identity.SingleStochasticInterpolantIdentity
8
+ # fractional coordinates
9
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
10
+ init_args:
11
+ interpolant:
12
+ class_path: omg.si.interpolants.PeriodicEncoderDecoderInterpolant
13
+ init_args:
14
+ switch_time: 0.42184997325946555
15
+ power: 0.5
16
+ gamma:
17
+ class_path: omg.si.gamma.LatentGammaEncoderDecoder
18
+ init_args:
19
+ a: 0.03989185248799893
20
+ switch_time: 0.42184997325946555
21
+ power: 0.5
22
+ epsilon:
23
+ class_path: omg.si.epsilon.VanishingEpsilon
24
+ init_args:
25
+ c: 2.3996529332194574
26
+ mu: 0.25251095399328916
27
+ sigma: 0.03759134500470063
28
+ differential_equation_type: "SDE"
29
+ integrator_kwargs:
30
+ method: "euler"
31
+ dt: 0.0014076164225116372
32
+ velocity_annealing_factor: 3.7755089557808477
33
+ correct_center_of_mass_motion: true
34
+ # lattice vectors
35
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
36
+ init_args:
37
+ interpolant: omg.si.interpolants.LinearInterpolant
38
+ gamma:
39
+ class_path: omg.si.gamma.LatentGammaSqrt
40
+ init_args:
41
+ a: 4.961271013084809
42
+ epsilon: null
43
+ differential_equation_type: "ODE"
44
+ integrator_kwargs:
45
+ method: "euler"
46
+ velocity_annealing_factor: 1.1379701544400436
47
+ correct_center_of_mass_motion: false
48
+ data_fields:
49
+ # if the order of the data_fields changes,
50
+ # the order of the above StochasticInterpolant inputs must also change
51
+ - "species"
52
+ - "pos"
53
+ - "cell"
54
+ integration_time_steps: 710
55
+ relative_si_costs:
56
+ species_loss: 0.0
57
+ pos_loss_b: 0.6143090042317803
58
+ pos_loss_z: 0.3794040725288834
59
+ cell_loss_b: 0.00628692323933625
60
+ sampler:
61
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
62
+ init_args:
63
+ pos_distribution: null
64
+ cell_distribution:
65
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
66
+ init_args:
67
+ dataset_name: mpts_52
68
+ species_distribution:
69
+ class_path: omg.sampler.distributions.MirrorData
70
+ model:
71
+ class_path: omg.model.model.Model
72
+ init_args:
73
+ encoder:
74
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
75
+ head:
76
+ class_path: omg.model.heads.pass_through.PassThrough
77
+ time_embedder:
78
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
79
+ init_args:
80
+ dim: 256
81
+ use_min_perm_dist: False
82
+ float_32_matmul_precision: "high"
83
+ validation_mode: "match_rate"
84
+ number_cpus: 7
85
+ dataset_name: "mpts_52"
86
+ data:
87
+ train_dataset:
88
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
89
+ init_args:
90
+ dataset:
91
+ class_path: omg.datamodule.datamodule.DataModule
92
+ init_args:
93
+ lmdb_paths:
94
+ - "data/mpts_52/train.lmdb"
95
+ niggli: False
96
+ val_dataset:
97
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
98
+ init_args:
99
+ dataset:
100
+ class_path: omg.datamodule.datamodule.DataModule
101
+ init_args:
102
+ lmdb_paths:
103
+ - "data/mpts_52/val.lmdb"
104
+ niggli: False
105
+ predict_dataset:
106
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
107
+ init_args:
108
+ dataset:
109
+ class_path: omg.datamodule.datamodule.DataModule
110
+ init_args:
111
+ lmdb_paths:
112
+ - "data/mpts_52/test.lmdb"
113
+ niggli: False
114
+ batch_size: 32
115
+ num_workers: 4
116
+ pin_memory: True
117
+ persistent_workers: True
118
+ trainer:
119
+ callbacks:
120
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
121
+ init_args:
122
+ filename: "best_val_loss_total"
123
+ save_top_k: 1
124
+ monitor: "val_loss_total"
125
+ save_weights_only: true
126
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
127
+ init_args:
128
+ filename: "best_val_match_rate"
129
+ save_top_k: 1
130
+ monitor: "match_rate"
131
+ save_weights_only: true
132
+ mode: 'max'
133
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
134
+ init_args:
135
+ filename: "best_val_rmsd"
136
+ save_top_k: 1
137
+ monitor: "mean_rmsd"
138
+ save_weights_only: true
139
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
140
+ init_args:
141
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
142
+ monitor: "val_loss_total"
143
+ every_n_epochs: 100
144
+ save_weights_only: false
145
+ gradient_clip_val: 0.5
146
+ num_sanity_val_steps: 0
147
+ precision: "32-true"
148
+ max_epochs: 2000
149
+ enable_progress_bar: true
150
+ limit_val_batches: 0.5
151
+ check_val_every_n_epoch: 100
152
+ optimizer:
153
+ class_path: torch.optim.Adam
154
+ init_args:
155
+ lr: 0.00018567271191860665
Linear-ODE-Gamma/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4a2de33650021f1e54c181668bfb34de6b845c4b5ce6422abb4f44e85f31ebb5
3
+ size 49644411
Linear-ODE-Gamma/train.yaml ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.single_stochastic_interpolant_identity.SingleStochasticInterpolantIdentity
8
+ # fractional coordinates
9
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
10
+ init_args:
11
+ interpolant: omg.si.interpolants.PeriodicLinearInterpolant
12
+ gamma:
13
+ class_path: omg.si.gamma.LatentGammaSqrt
14
+ init_args:
15
+ a: 0.2575112227566439
16
+ epsilon: null
17
+ differential_equation_type: "ODE"
18
+ integrator_kwargs:
19
+ method: "euler"
20
+ velocity_annealing_factor: 7.7611189744870925
21
+ correct_center_of_mass_motion: true
22
+ # lattice vectors
23
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
24
+ init_args:
25
+ interpolant: omg.si.interpolants.TrigonometricInterpolant
26
+ gamma:
27
+ class_path: omg.si.gamma.LatentGammaSqrt
28
+ init_args:
29
+ a: 2.9759856920732597
30
+ epsilon: null
31
+ differential_equation_type: "ODE"
32
+ integrator_kwargs:
33
+ method: "euler"
34
+ velocity_annealing_factor: 4.116061496782678
35
+ correct_center_of_mass_motion: false
36
+ data_fields:
37
+ # if the order of the data_fields changes,
38
+ # the order of the above StochasticInterpolant inputs must also change
39
+ - "species"
40
+ - "pos"
41
+ - "cell"
42
+ integration_time_steps: 690
43
+ relative_si_costs:
44
+ species_loss: 0.0
45
+ pos_loss_b: 0.9976417941296929
46
+ cell_loss_b: 0.002358205870307133
47
+ sampler:
48
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
49
+ init_args:
50
+ pos_distribution: null
51
+ cell_distribution:
52
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
53
+ init_args:
54
+ dataset_name: mpts_52
55
+ species_distribution:
56
+ class_path: omg.sampler.distributions.MirrorData
57
+ model:
58
+ class_path: omg.model.model.Model
59
+ init_args:
60
+ encoder:
61
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
62
+ head:
63
+ class_path: omg.model.heads.pass_through.PassThrough
64
+ time_embedder:
65
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
66
+ init_args:
67
+ dim: 256
68
+ use_min_perm_dist: False
69
+ float_32_matmul_precision: "high"
70
+ validation_mode: "match_rate"
71
+ number_cpus: 7
72
+ dataset_name: "mpts_52"
73
+ data:
74
+ train_dataset:
75
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
76
+ init_args:
77
+ dataset:
78
+ class_path: omg.datamodule.datamodule.DataModule
79
+ init_args:
80
+ lmdb_paths:
81
+ - "data/mpts_52/train.lmdb"
82
+ niggli: False
83
+ val_dataset:
84
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
85
+ init_args:
86
+ dataset:
87
+ class_path: omg.datamodule.datamodule.DataModule
88
+ init_args:
89
+ lmdb_paths:
90
+ - "data/mpts_52/val.lmdb"
91
+ niggli: False
92
+ predict_dataset:
93
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
94
+ init_args:
95
+ dataset:
96
+ class_path: omg.datamodule.datamodule.DataModule
97
+ init_args:
98
+ lmdb_paths:
99
+ - "data/mpts_52/test.lmdb"
100
+ niggli: False
101
+ batch_size: 128
102
+ num_workers: 4
103
+ pin_memory: True
104
+ persistent_workers: True
105
+ trainer:
106
+ callbacks:
107
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
108
+ init_args:
109
+ filename: "best_val_loss_total"
110
+ save_top_k: 1
111
+ monitor: "val_loss_total"
112
+ save_weights_only: true
113
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
114
+ init_args:
115
+ filename: "best_val_match_rate"
116
+ save_top_k: 1
117
+ monitor: "match_rate"
118
+ save_weights_only: true
119
+ mode: 'max'
120
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
121
+ init_args:
122
+ filename: "best_val_rmsd"
123
+ save_top_k: 1
124
+ monitor: "mean_rmsd"
125
+ save_weights_only: true
126
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
127
+ init_args:
128
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
129
+ monitor: "val_loss_total"
130
+ every_n_epochs: 100
131
+ save_weights_only: false
132
+ gradient_clip_val: 0.5
133
+ num_sanity_val_steps: 0
134
+ precision: "32-true"
135
+ max_epochs: 2000
136
+ enable_progress_bar: true
137
+ limit_val_batches: 0.5
138
+ check_val_every_n_epoch: 100
139
+ optimizer:
140
+ class_path: torch.optim.Adam
141
+ init_args:
142
+ lr: 4.006249666984122e-05
Linear-ODE/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1c189017e4d30b78995435af4e8dc42ca8942f64633d662762928680c536e211
3
+ size 49644411
Linear-ODE/train.yaml ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.single_stochastic_interpolant_identity.SingleStochasticInterpolantIdentity
8
+ # fractional coordinates
9
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
10
+ init_args:
11
+ interpolant: omg.si.interpolants.PeriodicLinearInterpolant
12
+ gamma: null
13
+ epsilon: null
14
+ differential_equation_type: "ODE"
15
+ integrator_kwargs:
16
+ method: "euler"
17
+ velocity_annealing_factor: 12.752963137656907
18
+ correct_center_of_mass_motion: true
19
+ # lattice vectors
20
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
21
+ init_args:
22
+ interpolant: omg.si.interpolants.LinearInterpolant
23
+ gamma: null
24
+ epsilon: null
25
+ differential_equation_type: "ODE"
26
+ integrator_kwargs:
27
+ method: "euler"
28
+ velocity_annealing_factor: 0.9964121490291458
29
+ correct_center_of_mass_motion: false
30
+ data_fields:
31
+ # if the order of the data_fields changes,
32
+ # the order of the above StochasticInterpolant inputs must also change
33
+ - "species"
34
+ - "pos"
35
+ - "cell"
36
+ integration_time_steps: 100
37
+ relative_si_costs:
38
+ species_loss: 0.0
39
+ pos_loss_b: 0.9983149306572928
40
+ cell_loss_b: 0.0016850693427072152
41
+ sampler:
42
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
43
+ init_args:
44
+ pos_distribution: null
45
+ cell_distribution:
46
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
47
+ init_args:
48
+ dataset_name: mpts_52
49
+ species_distribution:
50
+ class_path: omg.sampler.distributions.MirrorData
51
+ model:
52
+ class_path: omg.model.model.Model
53
+ init_args:
54
+ encoder:
55
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
56
+ head:
57
+ class_path: omg.model.heads.pass_through.PassThrough
58
+ time_embedder:
59
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
60
+ init_args:
61
+ dim: 256
62
+ use_min_perm_dist: False
63
+ float_32_matmul_precision: "high"
64
+ validation_mode: "match_rate"
65
+ dataset_name: "mpts_52"
66
+ data:
67
+ train_dataset:
68
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
69
+ init_args:
70
+ dataset:
71
+ class_path: omg.datamodule.datamodule.DataModule
72
+ init_args:
73
+ lmdb_paths:
74
+ - "data/mpts_52/train.lmdb"
75
+ niggli: False
76
+ val_dataset:
77
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
78
+ init_args:
79
+ dataset:
80
+ class_path: omg.datamodule.datamodule.DataModule
81
+ init_args:
82
+ lmdb_paths:
83
+ - "data/mpts_52/val.lmdb"
84
+ niggli: False
85
+ predict_dataset:
86
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
87
+ init_args:
88
+ dataset:
89
+ class_path: omg.datamodule.datamodule.DataModule
90
+ init_args:
91
+ lmdb_paths:
92
+ - "data/mpts_52/test.lmdb"
93
+ niggli: False
94
+ batch_size: 512
95
+ num_workers: 4
96
+ pin_memory: True
97
+ persistent_workers: True
98
+ trainer:
99
+ callbacks:
100
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
101
+ init_args:
102
+ filename: "best_val_loss_total"
103
+ save_top_k: 1
104
+ monitor: "val_loss_total"
105
+ save_weights_only: true
106
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
107
+ init_args:
108
+ filename: "best_val_match_rate"
109
+ save_top_k: 1
110
+ monitor: "match_rate"
111
+ save_weights_only: true
112
+ mode: 'max'
113
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
114
+ init_args:
115
+ filename: "best_val_rmsd"
116
+ save_top_k: 1
117
+ monitor: "mean_rmsd"
118
+ save_weights_only: true
119
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
120
+ init_args:
121
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
122
+ monitor: "val_loss_total"
123
+ every_n_epochs: 100
124
+ save_weights_only: false
125
+ gradient_clip_val: 0.5
126
+ num_sanity_val_steps: 0
127
+ precision: "32-true"
128
+ max_epochs: 10000
129
+ enable_progress_bar: false
130
+ check_val_every_n_epoch: 100
131
+ optimizer:
132
+ class_path: torch.optim.Adam
133
+ init_args:
134
+ lr: 0.0005546288717347031
Linear-SDE-Gamma/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a44fb11e6327471ca725a66137dee984529a626330cd137b51d6882ea39996f4
3
+ size 148120276
Linear-SDE-Gamma/train.yaml ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.single_stochastic_interpolant_identity.SingleStochasticInterpolantIdentity
8
+ # fractional coordinates
9
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
10
+ init_args:
11
+ interpolant: omg.si.interpolants.PeriodicLinearInterpolant
12
+ gamma:
13
+ class_path: omg.si.gamma.LatentGammaSqrt
14
+ init_args:
15
+ a: 0.06285652866840548
16
+ epsilon:
17
+ class_path: omg.si.epsilon.VanishingEpsilon
18
+ init_args:
19
+ c: 6.097168392667226
20
+ mu: 0.21833859329765842
21
+ sigma: 0.04985718977712428
22
+ differential_equation_type: "SDE"
23
+ integrator_kwargs:
24
+ method: "euler"
25
+ dt: 0.0032297736033797264
26
+ velocity_annealing_factor: 11.58289329358004
27
+ correct_center_of_mass_motion: true
28
+ # lattice vectors
29
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
30
+ init_args:
31
+ interpolant: omg.si.interpolants.LinearInterpolant
32
+ gamma:
33
+ class_path: omg.si.gamma.LatentGammaSqrt
34
+ init_args:
35
+ a: 0.1317493001266121
36
+ epsilon:
37
+ class_path: omg.si.epsilon.VanishingEpsilon
38
+ init_args:
39
+ c: 9.612495617660462
40
+ mu: 0.08389382419092543
41
+ sigma: 0.033192886798663945
42
+ differential_equation_type: "SDE"
43
+ integrator_kwargs:
44
+ method: "euler"
45
+ dt: 0.0032297736033797264
46
+ velocity_annealing_factor: 5.081210983525862
47
+ correct_center_of_mass_motion: false
48
+ data_fields:
49
+ # if the order of the data_fields changes,
50
+ # the order of the above StochasticInterpolant inputs must also change
51
+ - "species"
52
+ - "pos"
53
+ - "cell"
54
+ integration_time_steps: 310
55
+ relative_si_costs:
56
+ species_loss: 0.0
57
+ pos_loss_b: 0.007345481151868809
58
+ pos_loss_z: 0.9153543617007412
59
+ cell_loss_b: 0.06421063793348068
60
+ cell_loss_z: 0.013089519213909303
61
+ sampler:
62
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
63
+ init_args:
64
+ pos_distribution: null
65
+ cell_distribution:
66
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
67
+ init_args:
68
+ dataset_name: mpts_52
69
+ species_distribution:
70
+ class_path: omg.sampler.distributions.MirrorData
71
+ model:
72
+ class_path: omg.model.model.Model
73
+ init_args:
74
+ encoder:
75
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
76
+ head:
77
+ class_path: omg.model.heads.pass_through.PassThrough
78
+ time_embedder:
79
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
80
+ init_args:
81
+ dim: 256
82
+ use_min_perm_dist: False
83
+ float_32_matmul_precision: "high"
84
+ validation_mode: "match_rate"
85
+ number_cpus: 7
86
+ dataset_name: "mpts_52"
87
+ data:
88
+ train_dataset:
89
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
90
+ init_args:
91
+ dataset:
92
+ class_path: omg.datamodule.datamodule.DataModule
93
+ init_args:
94
+ lmdb_paths:
95
+ - "data/mpts_52/train.lmdb"
96
+ niggli: False
97
+ val_dataset:
98
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
99
+ init_args:
100
+ dataset:
101
+ class_path: omg.datamodule.datamodule.DataModule
102
+ init_args:
103
+ lmdb_paths:
104
+ - "data/mpts_52/val.lmdb"
105
+ niggli: False
106
+ predict_dataset:
107
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
108
+ init_args:
109
+ dataset:
110
+ class_path: omg.datamodule.datamodule.DataModule
111
+ init_args:
112
+ lmdb_paths:
113
+ - "data/mpts_52/test.lmdb"
114
+ niggli: False
115
+ batch_size: 256
116
+ num_workers: 4
117
+ pin_memory: True
118
+ persistent_workers: True
119
+ trainer:
120
+ callbacks:
121
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
122
+ init_args:
123
+ filename: "best_val_loss_total"
124
+ save_top_k: 1
125
+ monitor: "val_loss_total"
126
+ save_weights_only: true
127
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
128
+ init_args:
129
+ filename: "best_val_match_rate"
130
+ save_top_k: 1
131
+ monitor: "match_rate"
132
+ save_weights_only: true
133
+ mode: 'max'
134
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
135
+ init_args:
136
+ filename: "best_val_rmsd"
137
+ save_top_k: 1
138
+ monitor: "mean_rmsd"
139
+ save_weights_only: true
140
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
141
+ init_args:
142
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
143
+ monitor: "val_loss_total"
144
+ every_n_epochs: 100
145
+ save_weights_only: false
146
+ gradient_clip_val: 0.5
147
+ num_sanity_val_steps: 0
148
+ precision: "32-true"
149
+ max_epochs: 2000
150
+ enable_progress_bar: true
151
+ limit_val_batches: 0.5
152
+ check_val_every_n_epoch: 100
153
+ optimizer:
154
+ class_path: torch.optim.Adam
155
+ init_args:
156
+ lr: 0.0002629870131361822
Trig-ODE-Gamma/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2300d4bfa2684ba81cdaac6277e4fb2205bb8b8c380022e1346169ffd9eda1fa
3
+ size 148107354
Trig-ODE-Gamma/train.yaml ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.single_stochastic_interpolant_identity.SingleStochasticInterpolantIdentity
8
+ # fractional coordinates
9
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
10
+ init_args:
11
+ interpolant: omg.si.interpolants.PeriodicTrigonometricInterpolant
12
+ gamma:
13
+ class_path: omg.si.gamma.LatentGammaSqrt
14
+ init_args:
15
+ a: 0.03337798944475465
16
+ epsilon: null
17
+ differential_equation_type: "ODE"
18
+ integrator_kwargs:
19
+ method: "euler"
20
+ velocity_annealing_factor: 13.545929738762764
21
+ correct_center_of_mass_motion: true
22
+ # lattice vectors
23
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
24
+ init_args:
25
+ interpolant: omg.si.interpolants.LinearInterpolant
26
+ gamma:
27
+ class_path: omg.si.gamma.LatentGammaSqrt
28
+ init_args:
29
+ a: 0.017261010545698854
30
+ epsilon:
31
+ class_path: omg.si.epsilon.VanishingEpsilon
32
+ init_args:
33
+ c: 0.8758328635983847
34
+ mu: 0.29744423858325936
35
+ sigma: 0.0052236060273636595
36
+ differential_equation_type: "SDE"
37
+ integrator_kwargs:
38
+ method: "euler"
39
+ dt: 0.0012811297783628106
40
+ velocity_annealing_factor: 2.380421528846764
41
+ correct_center_of_mass_motion: false
42
+ data_fields:
43
+ # if the order of the data_fields changes,
44
+ # the order of the above StochasticInterpolant inputs must also change
45
+ - "species"
46
+ - "pos"
47
+ - "cell"
48
+ integration_time_steps: 780
49
+ relative_si_costs:
50
+ species_loss: 0.0
51
+ pos_loss_b: 0.983015308902659
52
+ cell_loss_b: 0.01673796318800159
53
+ cell_loss_z: 0.0002467279093394523
54
+ sampler:
55
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
56
+ init_args:
57
+ pos_distribution: null
58
+ cell_distribution:
59
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
60
+ init_args:
61
+ dataset_name: mpts_52
62
+ species_distribution:
63
+ class_path: omg.sampler.distributions.MirrorData
64
+ model:
65
+ class_path: omg.model.model.Model
66
+ init_args:
67
+ encoder:
68
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
69
+ head:
70
+ class_path: omg.model.heads.pass_through.PassThrough
71
+ time_embedder:
72
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
73
+ init_args:
74
+ dim: 256
75
+ use_min_perm_dist: False
76
+ float_32_matmul_precision: "high"
77
+ validation_mode: "match_rate"
78
+ number_cpus: 7
79
+ dataset_name: "mpts_52"
80
+ data:
81
+ train_dataset:
82
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
83
+ init_args:
84
+ dataset:
85
+ class_path: omg.datamodule.datamodule.DataModule
86
+ init_args:
87
+ lmdb_paths:
88
+ - "data/mpts_52/train.lmdb"
89
+ niggli: False
90
+ val_dataset:
91
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
92
+ init_args:
93
+ dataset:
94
+ class_path: omg.datamodule.datamodule.DataModule
95
+ init_args:
96
+ lmdb_paths:
97
+ - "data/mpts_52/val.lmdb"
98
+ niggli: False
99
+ predict_dataset:
100
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
101
+ init_args:
102
+ dataset:
103
+ class_path: omg.datamodule.datamodule.DataModule
104
+ init_args:
105
+ lmdb_paths:
106
+ - "data/mpts_52/test.lmdb"
107
+ niggli: False
108
+ batch_size: 32
109
+ num_workers: 4
110
+ pin_memory: True
111
+ persistent_workers: True
112
+ trainer:
113
+ callbacks:
114
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
115
+ init_args:
116
+ filename: "best_val_loss_total"
117
+ save_top_k: 1
118
+ monitor: "val_loss_total"
119
+ save_weights_only: true
120
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
121
+ init_args:
122
+ filename: "best_val_match_rate"
123
+ save_top_k: 1
124
+ monitor: "match_rate"
125
+ save_weights_only: true
126
+ mode: 'max'
127
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
128
+ init_args:
129
+ filename: "best_val_rmsd"
130
+ save_top_k: 1
131
+ monitor: "mean_rmsd"
132
+ save_weights_only: true
133
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
134
+ init_args:
135
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
136
+ monitor: "val_loss_total"
137
+ every_n_epochs: 100
138
+ save_weights_only: false
139
+ gradient_clip_val: 0.5
140
+ num_sanity_val_steps: 0
141
+ precision: "32-true"
142
+ max_epochs: 2000
143
+ enable_progress_bar: true
144
+ limit_val_batches: 0.5
145
+ check_val_every_n_epoch: 100
146
+ optimizer:
147
+ class_path: torch.optim.Adam
148
+ init_args:
149
+ lr: 8.341737878937152e-05
Trig-ODE/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4def0d650b785ca9fe3b818229d140d1885d1ff06bc64e75c15fbb0fb553f2e8
3
+ size 148107226
Trig-ODE/train.yaml ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.single_stochastic_interpolant_identity.SingleStochasticInterpolantIdentity
8
+ # fractional coordinates
9
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
10
+ init_args:
11
+ interpolant: omg.si.interpolants.PeriodicTrigonometricInterpolant
12
+ gamma: null
13
+ epsilon: null
14
+ differential_equation_type: "ODE"
15
+ integrator_kwargs:
16
+ method: "euler"
17
+ velocity_annealing_factor: 12.34532470785473
18
+ correct_center_of_mass_motion: true
19
+ # lattice vectors
20
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
21
+ init_args:
22
+ interpolant:
23
+ class_path: omg.si.interpolants.EncoderDecoderInterpolant
24
+ init_args:
25
+ switch_time: 0.4080329374611481
26
+ power: 0.5
27
+ gamma:
28
+ class_path: omg.si.gamma.LatentGammaEncoderDecoder
29
+ init_args:
30
+ a: 5.270616141661882
31
+ switch_time: 0.4080329374611481
32
+ power: 0.5
33
+ epsilon:
34
+ class_path: omg.si.epsilon.VanishingEpsilon
35
+ init_args:
36
+ c: 4.354817546796119
37
+ mu: 0.2923928859901851
38
+ sigma: 0.04742031136770322
39
+ differential_equation_type: "SDE"
40
+ integrator_kwargs:
41
+ method: "euler"
42
+ dt: 0.005905325524508953
43
+ velocity_annealing_factor: 3.6141717997883447
44
+ correct_center_of_mass_motion: false
45
+ data_fields:
46
+ # if the order of the data_fields changes,
47
+ # the order of the above StochasticInterpolant inputs must also change
48
+ - "species"
49
+ - "pos"
50
+ - "cell"
51
+ integration_time_steps: 170
52
+ relative_si_costs:
53
+ species_loss: 0.0
54
+ pos_loss_b: 0.9967455480681945
55
+ cell_loss_b: 0.002271914623580616
56
+ cell_loss_z: 0.0009825373082248405
57
+ sampler:
58
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
59
+ init_args:
60
+ pos_distribution: null
61
+ cell_distribution:
62
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
63
+ init_args:
64
+ dataset_name: mpts_52
65
+ species_distribution:
66
+ class_path: omg.sampler.distributions.MirrorData
67
+ model:
68
+ class_path: omg.model.model.Model
69
+ init_args:
70
+ encoder:
71
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
72
+ head:
73
+ class_path: omg.model.heads.pass_through.PassThrough
74
+ time_embedder:
75
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
76
+ init_args:
77
+ dim: 256
78
+ use_min_perm_dist: False
79
+ float_32_matmul_precision: "high"
80
+ validation_mode: "match_rate"
81
+ number_cpus: 7
82
+ dataset_name: "mpts_52"
83
+ data:
84
+ train_dataset:
85
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
86
+ init_args:
87
+ dataset:
88
+ class_path: omg.datamodule.datamodule.DataModule
89
+ init_args:
90
+ lmdb_paths:
91
+ - "data/mpts_52/train.lmdb"
92
+ niggli: False
93
+ val_dataset:
94
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
95
+ init_args:
96
+ dataset:
97
+ class_path: omg.datamodule.datamodule.DataModule
98
+ init_args:
99
+ lmdb_paths:
100
+ - "data/mpts_52/val.lmdb"
101
+ niggli: False
102
+ predict_dataset:
103
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
104
+ init_args:
105
+ dataset:
106
+ class_path: omg.datamodule.datamodule.DataModule
107
+ init_args:
108
+ lmdb_paths:
109
+ - "data/mpts_52/test.lmdb"
110
+ niggli: False
111
+ batch_size: 32
112
+ num_workers: 4
113
+ pin_memory: True
114
+ persistent_workers: True
115
+ trainer:
116
+ callbacks:
117
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
118
+ init_args:
119
+ filename: "best_val_loss_total"
120
+ save_top_k: 1
121
+ monitor: "val_loss_total"
122
+ save_weights_only: true
123
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
124
+ init_args:
125
+ filename: "best_val_match_rate"
126
+ save_top_k: 1
127
+ monitor: "match_rate"
128
+ save_weights_only: true
129
+ mode: 'max'
130
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
131
+ init_args:
132
+ filename: "best_val_rmsd"
133
+ save_top_k: 1
134
+ monitor: "mean_rmsd"
135
+ save_weights_only: true
136
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
137
+ init_args:
138
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
139
+ monitor: "val_loss_total"
140
+ every_n_epochs: 100
141
+ save_weights_only: false
142
+ gradient_clip_val: 0.5
143
+ num_sanity_val_steps: 0
144
+ precision: "32-true"
145
+ max_epochs: 2000
146
+ enable_progress_bar: true
147
+ limit_val_batches: 0.5
148
+ check_val_every_n_epoch: 100
149
+ optimizer:
150
+ class_path: torch.optim.Adam
151
+ init_args:
152
+ lr: 3.629490873183724e-05
Trig-SDE-Gamma/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7fd313b4ce18394628596dce790ba5a6e58bd37a1fe5f7c0c78be6a2e3ab0fb5
3
+ size 148082778
Trig-SDE-Gamma/train.yaml ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.single_stochastic_interpolant_identity.SingleStochasticInterpolantIdentity
8
+ # fractional coordinates
9
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
10
+ init_args:
11
+ interpolant: omg.si.interpolants.PeriodicTrigonometricInterpolant
12
+ gamma:
13
+ class_path: omg.si.gamma.LatentGammaSqrt
14
+ init_args:
15
+ a: 0.049242906264339095
16
+ epsilon:
17
+ class_path: omg.si.epsilon.VanishingEpsilon
18
+ init_args:
19
+ c: 9.418703639528207
20
+ mu: 0.1967838464371502
21
+ sigma: 0.040028404066547216
22
+ differential_equation_type: "SDE"
23
+ integrator_kwargs:
24
+ method: "euler"
25
+ dt: 0.0013504737289622426
26
+ velocity_annealing_factor: 11.483173553510193
27
+ correct_center_of_mass_motion: true
28
+ # lattice vectors
29
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
30
+ init_args:
31
+ interpolant: omg.si.interpolants.TrigonometricInterpolant
32
+ gamma: null
33
+ epsilon: null
34
+ differential_equation_type: "ODE"
35
+ integrator_kwargs:
36
+ method: "euler"
37
+ velocity_annealing_factor: 0.4337356395028541
38
+ correct_center_of_mass_motion: false
39
+ data_fields:
40
+ # if the order of the data_fields changes,
41
+ # the order of the above StochasticInterpolant inputs must also change
42
+ - "species"
43
+ - "pos"
44
+ - "cell"
45
+ integration_time_steps: 740
46
+ relative_si_costs:
47
+ species_loss: 0.0
48
+ pos_loss_b: 0.24677273761024368
49
+ pos_loss_z: 0.7231540118244248
50
+ cell_loss_b: 0.030073250565331323
51
+ sampler:
52
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
53
+ init_args:
54
+ pos_distribution: null
55
+ cell_distribution:
56
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
57
+ init_args:
58
+ dataset_name: mpts_52
59
+ species_distribution:
60
+ class_path: omg.sampler.distributions.MirrorData
61
+ model:
62
+ class_path: omg.model.model.Model
63
+ init_args:
64
+ encoder:
65
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
66
+ head:
67
+ class_path: omg.model.heads.pass_through.PassThrough
68
+ time_embedder:
69
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
70
+ init_args:
71
+ dim: 256
72
+ use_min_perm_dist: True
73
+ float_32_matmul_precision: "high"
74
+ validation_mode: "match_rate"
75
+ number_cpus: 7
76
+ dataset_name: "mpts_52"
77
+ data:
78
+ train_dataset:
79
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
80
+ init_args:
81
+ dataset:
82
+ class_path: omg.datamodule.datamodule.DataModule
83
+ init_args:
84
+ lmdb_paths:
85
+ - "data/mpts_52/train.lmdb"
86
+ niggli: False
87
+ val_dataset:
88
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
89
+ init_args:
90
+ dataset:
91
+ class_path: omg.datamodule.datamodule.DataModule
92
+ init_args:
93
+ lmdb_paths:
94
+ - "data/mpts_52/val.lmdb"
95
+ niggli: False
96
+ predict_dataset:
97
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
98
+ init_args:
99
+ dataset:
100
+ class_path: omg.datamodule.datamodule.DataModule
101
+ init_args:
102
+ lmdb_paths:
103
+ - "data/mpts_52/test.lmdb"
104
+ niggli: False
105
+ batch_size: 32
106
+ num_workers: 4
107
+ pin_memory: True
108
+ persistent_workers: True
109
+ trainer:
110
+ callbacks:
111
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
112
+ init_args:
113
+ filename: "best_val_loss_total"
114
+ save_top_k: 1
115
+ monitor: "val_loss_total"
116
+ save_weights_only: true
117
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
118
+ init_args:
119
+ filename: "best_val_match_rate"
120
+ save_top_k: 1
121
+ monitor: "match_rate"
122
+ save_weights_only: true
123
+ mode: 'max'
124
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
125
+ init_args:
126
+ filename: "best_val_rmsd"
127
+ save_top_k: 1
128
+ monitor: "mean_rmsd"
129
+ save_weights_only: true
130
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
131
+ init_args:
132
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
133
+ monitor: "val_loss_total"
134
+ every_n_epochs: 100
135
+ save_weights_only: false
136
+ gradient_clip_val: 0.5
137
+ num_sanity_val_steps: 0
138
+ precision: "32-true"
139
+ max_epochs: 2000
140
+ enable_progress_bar: true
141
+ limit_val_batches: 0.5
142
+ check_val_every_n_epoch: 100
143
+ optimizer:
144
+ class_path: torch.optim.Adam
145
+ init_args:
146
+ lr: 9.320780466656964e-05
VESBD-ODE/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:64cb0f57a91fd0d53d8710f6a6e5e86baefc4ee18a5bf7c0db7e74aad10e89b5
3
+ size 49644411
VESBD-ODE/train.yaml ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.single_stochastic_interpolant_identity.SingleStochasticInterpolantIdentity
8
+ # fractional coordinates
9
+ - class_path: omg.si.single_stochastic_interpolant_os.SingleStochasticInterpolantOS
10
+ init_args:
11
+ interpolant:
12
+ class_path: omg.si.interpolants.PeriodicScoreBasedDiffusionModelInterpolantVE
13
+ init_args:
14
+ sigma:
15
+ class_path: omg.si.sigma.GeometricSigma
16
+ init_args:
17
+ sigma_min: 0.004705415831077799
18
+ sigma_max: 0.9967130801483843
19
+ epsilon: null
20
+ differential_equation_type: "ODE"
21
+ integrator_kwargs:
22
+ method: "euler"
23
+ velocity_annealing_factor: 8.284579088906593
24
+ correct_center_of_mass_motion: true
25
+ predict_velocity: true
26
+ # lattice vectors
27
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
28
+ init_args:
29
+ interpolant: omg.si.interpolants.LinearInterpolant
30
+ gamma:
31
+ class_path: omg.si.gamma.LatentGammaSqrt
32
+ init_args:
33
+ a: 0.016616684357970132
34
+ epsilon:
35
+ class_path: omg.si.epsilon.VanishingEpsilon
36
+ init_args:
37
+ c: 3.9372558236242052
38
+ mu: 0.2649556265396099
39
+ sigma: 0.03578203230805775
40
+ differential_equation_type: "SDE"
41
+ integrator_kwargs:
42
+ method: "euler"
43
+ dt: 0.0015144158387556672
44
+ velocity_annealing_factor: 0.42775377056075214
45
+ correct_center_of_mass_motion: false
46
+ data_fields:
47
+ # if the order of the data_fields changes,
48
+ # the order of the above StochasticInterpolant inputs must also change
49
+ - "species"
50
+ - "pos"
51
+ - "cell"
52
+ integration_time_steps: 660
53
+ relative_si_costs:
54
+ species_loss: 0.0
55
+ pos_loss_b: 0.9813067351598369
56
+ cell_loss_b: 0.0005256953168558359
57
+ cell_loss_z: 0.018167569523307267
58
+ sampler:
59
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
60
+ init_args:
61
+ pos_distribution:
62
+ class_path: omg.sampler.distributions.NormalDistribution
63
+ init_args:
64
+ scale: 9.77149759679434
65
+ cell_distribution:
66
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
67
+ init_args:
68
+ dataset_name: mpts_52
69
+ species_distribution:
70
+ class_path: omg.sampler.distributions.MirrorData
71
+ model:
72
+ class_path: omg.model.model.Model
73
+ init_args:
74
+ encoder:
75
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
76
+ head:
77
+ class_path: omg.model.heads.pass_through.PassThrough
78
+ time_embedder:
79
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
80
+ init_args:
81
+ dim: 256
82
+ use_min_perm_dist: False
83
+ float_32_matmul_precision: "high"
84
+ validation_mode: "match_rate"
85
+ number_cpus: 7
86
+ dataset_name: "mpts_52"
87
+ data:
88
+ train_dataset:
89
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
90
+ init_args:
91
+ dataset:
92
+ class_path: omg.datamodule.datamodule.DataModule
93
+ init_args:
94
+ lmdb_paths:
95
+ - "data/mpts_52/train.lmdb"
96
+ niggli: False
97
+ val_dataset:
98
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
99
+ init_args:
100
+ dataset:
101
+ class_path: omg.datamodule.datamodule.DataModule
102
+ init_args:
103
+ lmdb_paths:
104
+ - "data/mpts_52/val.lmdb"
105
+ niggli: False
106
+ predict_dataset:
107
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
108
+ init_args:
109
+ dataset:
110
+ class_path: omg.datamodule.datamodule.DataModule
111
+ init_args:
112
+ lmdb_paths:
113
+ - "data/mpts_52/test.lmdb"
114
+ niggli: False
115
+ batch_size: 256
116
+ num_workers: 4
117
+ pin_memory: True
118
+ persistent_workers: True
119
+ trainer:
120
+ callbacks:
121
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
122
+ init_args:
123
+ filename: "best_val_loss_total"
124
+ save_top_k: 1
125
+ monitor: "val_loss_total"
126
+ save_weights_only: true
127
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
128
+ init_args:
129
+ filename: "best_val_match_rate"
130
+ save_top_k: 1
131
+ monitor: "match_rate"
132
+ save_weights_only: true
133
+ mode: 'max'
134
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
135
+ init_args:
136
+ filename: "best_val_rmsd"
137
+ save_top_k: 1
138
+ monitor: "mean_rmsd"
139
+ save_weights_only: true
140
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
141
+ init_args:
142
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
143
+ monitor: "val_loss_total"
144
+ every_n_epochs: 100
145
+ save_weights_only: false
146
+ gradient_clip_val: 0.5
147
+ num_sanity_val_steps: 0
148
+ precision: "32-true"
149
+ max_epochs: 2000
150
+ enable_progress_bar: true
151
+ limit_val_batches: 0.5
152
+ check_val_every_n_epoch: 100
153
+ optimizer:
154
+ class_path: torch.optim.Adam
155
+ init_args:
156
+ lr: 0.000296636127734534
VPSBD-ODE/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f39a43f38e06f45c8185d13fee14d4d14687b6979120958a9640c461bd0e6181
3
+ size 148069280
VPSBD-ODE/train.yaml ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.single_stochastic_interpolant_identity.SingleStochasticInterpolantIdentity
8
+ # fractional coordinates
9
+ - class_path: omg.si.single_stochastic_interpolant_os.SingleStochasticInterpolantOS
10
+ init_args:
11
+ interpolant: omg.si.interpolants.PeriodicScoreBasedDiffusionModelInterpolant
12
+ epsilon: null
13
+ differential_equation_type: "ODE"
14
+ integrator_kwargs:
15
+ method: "euler"
16
+ velocity_annealing_factor: 6.613808424917352
17
+ correct_center_of_mass_motion: true
18
+ predict_velocity: true
19
+ # lattice vectors
20
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
21
+ init_args:
22
+ interpolant: omg.si.interpolants.LinearInterpolant
23
+ gamma: null
24
+ epsilon: null
25
+ differential_equation_type: "ODE"
26
+ integrator_kwargs:
27
+ method: "euler"
28
+ velocity_annealing_factor: 2.447993013544224
29
+ correct_center_of_mass_motion: false
30
+ data_fields:
31
+ # if the order of the data_fields changes,
32
+ # the order of the above StochasticInterpolant inputs must also change
33
+ - "species"
34
+ - "pos"
35
+ - "cell"
36
+ integration_time_steps: 890
37
+ relative_si_costs:
38
+ species_loss: 0.0
39
+ pos_loss_b: 0.9597565150933746
40
+ cell_loss_b: 0.04024348490662539
41
+ sampler:
42
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
43
+ init_args:
44
+ pos_distribution:
45
+ class_path: omg.sampler.distributions.NormalDistribution
46
+ init_args:
47
+ scale: 0.22006712732536396
48
+ cell_distribution:
49
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
50
+ init_args:
51
+ dataset_name: mpts_52
52
+ species_distribution:
53
+ class_path: omg.sampler.distributions.MirrorData
54
+ model:
55
+ class_path: omg.model.model.Model
56
+ init_args:
57
+ encoder:
58
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
59
+ head:
60
+ class_path: omg.model.heads.pass_through.PassThrough
61
+ time_embedder:
62
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
63
+ init_args:
64
+ dim: 256
65
+ use_min_perm_dist: True
66
+ float_32_matmul_precision: "high"
67
+ validation_mode: "match_rate"
68
+ number_cpus: 7
69
+ dataset_name: "mpts_52"
70
+ data:
71
+ train_dataset:
72
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
73
+ init_args:
74
+ dataset:
75
+ class_path: omg.datamodule.datamodule.DataModule
76
+ init_args:
77
+ lmdb_paths:
78
+ - "data/mpts_52/train.lmdb"
79
+ niggli: False
80
+ val_dataset:
81
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
82
+ init_args:
83
+ dataset:
84
+ class_path: omg.datamodule.datamodule.DataModule
85
+ init_args:
86
+ lmdb_paths:
87
+ - "data/mpts_52/val.lmdb"
88
+ niggli: False
89
+ predict_dataset:
90
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
91
+ init_args:
92
+ dataset:
93
+ class_path: omg.datamodule.datamodule.DataModule
94
+ init_args:
95
+ lmdb_paths:
96
+ - "data/mpts_52/test.lmdb"
97
+ niggli: False
98
+ batch_size: 64
99
+ num_workers: 4
100
+ pin_memory: True
101
+ persistent_workers: True
102
+ trainer:
103
+ callbacks:
104
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
105
+ init_args:
106
+ filename: "best_val_loss_total"
107
+ save_top_k: 1
108
+ monitor: "val_loss_total"
109
+ save_weights_only: true
110
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
111
+ init_args:
112
+ filename: "best_val_match_rate"
113
+ save_top_k: 1
114
+ monitor: "match_rate"
115
+ save_weights_only: true
116
+ mode: 'max'
117
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
118
+ init_args:
119
+ filename: "best_val_rmsd"
120
+ save_top_k: 1
121
+ monitor: "mean_rmsd"
122
+ save_weights_only: true
123
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
124
+ init_args:
125
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
126
+ monitor: "val_loss_total"
127
+ every_n_epochs: 100
128
+ save_weights_only: false
129
+ gradient_clip_val: 0.5
130
+ num_sanity_val_steps: 0
131
+ precision: "32-true"
132
+ max_epochs: 2000
133
+ enable_progress_bar: true
134
+ limit_val_batches: 0.5
135
+ check_val_every_n_epoch: 100
136
+ optimizer:
137
+ class_path: torch.optim.Adam
138
+ init_args:
139
+ lr: 2.519765029616902e-05
VPSBD-SDE/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2d49763d63ce905025b23efa6babcc6ab237cb1684daec33f055b4452dd3825c
3
+ size 49644475
VPSBD-SDE/train.yaml ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.single_stochastic_interpolant_identity.SingleStochasticInterpolantIdentity
8
+ # fractional coordinates
9
+ - class_path: omg.si.single_stochastic_interpolant_os.SingleStochasticInterpolantOS
10
+ init_args:
11
+ interpolant: omg.si.interpolants.PeriodicScoreBasedDiffusionModelInterpolant
12
+ epsilon:
13
+ class_path: omg.si.epsilon.VanishingEpsilon
14
+ init_args:
15
+ c: 2.4729222108905815
16
+ mu: 0.17656358406313838
17
+ sigma: 0.02379822283154629
18
+ differential_equation_type: "SDE"
19
+ integrator_kwargs:
20
+ method: "euler"
21
+ dt: 0.0016661101253703237
22
+ velocity_annealing_factor: 6.459028320375323
23
+ correct_center_of_mass_motion: true
24
+ predict_velocity: true
25
+ # lattice vectors
26
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
27
+ init_args:
28
+ interpolant: omg.si.interpolants.LinearInterpolant
29
+ gamma:
30
+ class_path: omg.si.gamma.LatentGammaSqrt
31
+ init_args:
32
+ a: 3.683542379054881
33
+ epsilon: null
34
+ differential_equation_type: "ODE"
35
+ integrator_kwargs:
36
+ method: "euler"
37
+ velocity_annealing_factor: 0.6692350794589719
38
+ correct_center_of_mass_motion: false
39
+ data_fields:
40
+ # if the order of the data_fields changes,
41
+ # the order of the above StochasticInterpolant inputs must also change
42
+ - "species"
43
+ - "pos"
44
+ - "cell"
45
+ integration_time_steps: 600
46
+ relative_si_costs:
47
+ species_loss: 0.0
48
+ pos_loss_b: 0.6060249654155797
49
+ pos_loss_z: 0.3828230559814603
50
+ cell_loss_b: 0.011151978602959979
51
+ sampler:
52
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
53
+ init_args:
54
+ pos_distribution:
55
+ class_path: omg.sampler.distributions.NormalDistribution
56
+ init_args:
57
+ scale: 2.2937003279036148
58
+ cell_distribution:
59
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
60
+ init_args:
61
+ dataset_name: mpts_52
62
+ species_distribution:
63
+ class_path: omg.sampler.distributions.MirrorData
64
+ model:
65
+ class_path: omg.model.model.Model
66
+ init_args:
67
+ encoder:
68
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
69
+ head:
70
+ class_path: omg.model.heads.pass_through.PassThrough
71
+ time_embedder:
72
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
73
+ init_args:
74
+ dim: 256
75
+ use_min_perm_dist: True
76
+ float_32_matmul_precision: "high"
77
+ validation_mode: "match_rate"
78
+ number_cpus: 7
79
+ dataset_name: "mpts_52"
80
+ data:
81
+ train_dataset:
82
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
83
+ init_args:
84
+ dataset:
85
+ class_path: omg.datamodule.datamodule.DataModule
86
+ init_args:
87
+ lmdb_paths:
88
+ - "data/mpts_52/train.lmdb"
89
+ niggli: False
90
+ val_dataset:
91
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
92
+ init_args:
93
+ dataset:
94
+ class_path: omg.datamodule.datamodule.DataModule
95
+ init_args:
96
+ lmdb_paths:
97
+ - "data/mpts_52/val.lmdb"
98
+ niggli: False
99
+ predict_dataset:
100
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
101
+ init_args:
102
+ dataset:
103
+ class_path: omg.datamodule.datamodule.DataModule
104
+ init_args:
105
+ lmdb_paths:
106
+ - "data/mpts_52/test.lmdb"
107
+ niggli: False
108
+ batch_size: 64
109
+ num_workers: 4
110
+ pin_memory: True
111
+ persistent_workers: True
112
+ trainer:
113
+ callbacks:
114
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
115
+ init_args:
116
+ filename: "best_val_loss_total"
117
+ save_top_k: 1
118
+ monitor: "val_loss_total"
119
+ save_weights_only: true
120
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
121
+ init_args:
122
+ filename: "best_val_match_rate"
123
+ save_top_k: 1
124
+ monitor: "match_rate"
125
+ save_weights_only: true
126
+ mode: 'max'
127
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
128
+ init_args:
129
+ filename: "best_val_rmsd"
130
+ save_top_k: 1
131
+ monitor: "mean_rmsd"
132
+ save_weights_only: true
133
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
134
+ init_args:
135
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
136
+ monitor: "val_loss_total"
137
+ every_n_epochs: 100
138
+ save_weights_only: false
139
+ gradient_clip_val: 0.5
140
+ num_sanity_val_steps: 0
141
+ precision: "32-true"
142
+ max_epochs: 2000
143
+ enable_progress_bar: true
144
+ limit_val_batches: 0.5
145
+ check_val_every_n_epoch: 100
146
+ optimizer:
147
+ class_path: torch.optim.Adam
148
+ init_args:
149
+ lr: 0.0003030820420973639