Gravitation Example

import numpy as np
import os
import tensorflow as tf
import torch
import maxent
from sbi_gravitation import GravitySimulator, sim_wrapper, get_observation_points
from torch.distributions.multivariate_normal import MultivariateNormal
from sbi.inference import infer
import scipy

import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import seaborn as sns
from functools import partialmethod
from tqdm import tqdm

# Globally disable tqdm progress bars (the inference libraries create them
# internally; without this the notebook output is flooded).
tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)

# Fix the NumPy RNG seed so prior draws and noise are reproducible.
np.random.seed(12656)
# Figure styling for the paper-quality plots below.
sns.set_context("paper")
sns.set_style(
    "white",
    {
        "xtick.bottom": True,
        "ytick.left": True,
        "xtick.color": "#333333",
        "ytick.color": "#333333",
    },
)
plt.rcParams["mathtext.fontset"] = "dejavuserif"
# Shared line colors used by the comparison plots below.
colors = ["#1b9e77", "#d95f02", "#7570b3", "#e7298a", "#66a61e"]
# set up true parameters
m1 = 100.0  # solar masses
m2 = 50.0  # solar masses
m3 = 75.0  # solar masses (float literal for consistency with m1/m2)
G = 1.90809e5  # solar radius / solar mass * (km/s)^2
v0 = np.array([15.0, -40.0])  # initial velocity, km/s

# ground-truth parameter vector [m1, m2, m3, v0x, v0y]; used for the
# reference lines in the posterior plots below
true_params = [m1, m2, m3, v0[0], v0[1]]

# set prior means (offset from the true values)
prior_means = [85.0, 40.0, 70.0, 12.0, -30.0]
# isotropic Gaussian prior covariance over the five parameters
prior_cov = np.eye(5) * 50
# generate true trajectory and apply some noise to it
# Cache the noiseless ground-truth trajectory on disk so reruns skip the sim.
if os.path.exists("true_trajectory.txt"):
    true_traj = np.genfromtxt("true_trajectory.txt")
else:
    sim = GravitySimulator(m1, m2, m3, v0, G, random_noise=False)
    true_traj = sim.run()
    np.savetxt("true_trajectory.txt", true_traj)

# Same caching for a noisy realization of the same trajectory.
if os.path.exists("noisy_trajectory.txt"):
    noisy_traj = np.genfromtxt("noisy_trajectory.txt")
else:
    sim = GravitySimulator(m1, m2, m3, v0, G, random_noise=True)
    noisy_traj = sim.run()
    np.savetxt("noisy_trajectory.txt", noisy_traj)

# Sparse "observed" points taken from the noisy path; flattened into a 1-D
# summary-statistic vector consumed by SNL below.
observed_points = get_observation_points(noisy_traj)
observation_summary_stats = observed_points.flatten()
# Re-run the noiseless simulator and display the reference trajectory.
sim = GravitySimulator(m1, m2, m3, v0, G, random_noise=False)
sim.run()
sim.plot_traj()
../_images/gravitation_3_0.png
# perform SNL inference
# Gaussian prior over [m1, m2, m3, v0x, v0y]. torch.eye already returns a
# tensor, so the previous torch.as_tensor(...) wrapper around it was a no-op
# and has been removed; dtype (float32) is unchanged.
prior = MultivariateNormal(
    loc=torch.as_tensor(prior_means),
    covariance_matrix=torch.eye(5) * 50,
)

# Sequential neural likelihood estimation via the sbi package; sim_wrapper
# maps a parameter vector to the simulated summary statistics.
posterior = infer(
    sim_wrapper, prior, method="SNLE", num_simulations=2048, num_workers=16
)
 Training neural network. Epochs trained: 1
 Training neural network. Epochs trained: 2
 Training neural network. Epochs trained: 3
 Training neural network. Epochs trained: 4
 Training neural network. Epochs trained: 5
 Training neural network. Epochs trained: 6
 Training neural network. Epochs trained: 7
 Training neural network. Epochs trained: 8
 Training neural network. Epochs trained: 9
 Training neural network. Epochs trained: 10
 Training neural network. Epochs trained: 11
 Training neural network. Epochs trained: 12
 Training neural network. Epochs trained: 13
 Training neural network. Epochs trained: 14
 Training neural network. Epochs trained: 15
 Training neural network. Epochs trained: 16
 Training neural network. Epochs trained: 17
 Training neural network. Epochs trained: 18
 Training neural network. Epochs trained: 19
 Training neural network. Epochs trained: 20
 Training neural network. Epochs trained: 21
 Training neural network. Epochs trained: 22
 Training neural network. Epochs trained: 23
 Training neural network. Epochs trained: 24
 Training neural network. Epochs trained: 25
 Training neural network. Epochs trained: 26
 Training neural network. Epochs trained: 27
 Training neural network. Epochs trained: 28
 Training neural network. Epochs trained: 29
 Training neural network. Epochs trained: 30
 Training neural network. Epochs trained: 31
 Training neural network. Epochs trained: 32
 Training neural network. Epochs trained: 33
 Training neural network. Epochs trained: 34
 Training neural network. Epochs trained: 35
 Training neural network. Epochs trained: 36
 Training neural network. Epochs trained: 37
 Training neural network. Epochs trained: 38
 Training neural network. Epochs trained: 39
 Training neural network. Epochs trained: 40
 Training neural network. Epochs trained: 41
 Training neural network. Epochs trained: 42
 Training neural network. Epochs trained: 43
 Training neural network. Epochs trained: 44
 Training neural network. Epochs trained: 45
 Training neural network. Epochs trained: 46
 Training neural network. Epochs trained: 47
 Training neural network. Epochs trained: 48
 Training neural network. Epochs trained: 49
 Training neural network. Epochs trained: 50
 Training neural network. Epochs trained: 51
 Training neural network. Epochs trained: 52
 Training neural network. Epochs trained: 53
 Training neural network. Epochs trained: 54
 Training neural network. Epochs trained: 55
 Training neural network. Epochs trained: 56
 Training neural network. Epochs trained: 57
 Training neural network. Epochs trained: 58
 Training neural network. Epochs trained: 59
 Training neural network. Epochs trained: 60
 Training neural network. Epochs trained: 61
 Training neural network. Epochs trained: 62
 Training neural network. Epochs trained: 63
 Training neural network. Epochs trained: 64
 Training neural network. Epochs trained: 65
 Training neural network. Epochs trained: 66
 Training neural network. Epochs trained: 67
 Training neural network. Epochs trained: 68
 Training neural network. Epochs trained: 69
 Training neural network. Epochs trained: 70
 Training neural network. Epochs trained: 71
 Training neural network. Epochs trained: 72
 Training neural network. Epochs trained: 73
 Training neural network. Epochs trained: 74
 Training neural network. Epochs trained: 75
 Training neural network. Epochs trained: 76
 Training neural network. Epochs trained: 77
 Training neural network. Epochs trained: 78
 Training neural network. Epochs trained: 79
 Training neural network. Epochs trained: 80
 Training neural network. Epochs trained: 81
 Training neural network. Epochs trained: 82
 Training neural network. Epochs trained: 83
 Training neural network. Epochs trained: 84
 Training neural network. Epochs trained: 85
 Training neural network. Epochs trained: 86
 Training neural network. Epochs trained: 87
 Training neural network. Epochs trained: 88
 Training neural network. Epochs trained: 89
 Training neural network. Epochs trained: 90
 Training neural network. Epochs trained: 91
 Training neural network. Epochs trained: 92
 Training neural network. Epochs trained: 93
 Training neural network. Epochs trained: 94
 Training neural network. Epochs trained: 95
 Training neural network. Epochs trained: 96
 Training neural network. Epochs trained: 97
 Training neural network. Epochs trained: 98
 Training neural network. Epochs trained: 99
 Training neural network. Epochs trained: 100
 Training neural network. Epochs trained: 101
 Training neural network. Epochs trained: 102
 Training neural network. Epochs trained: 103
 Training neural network. Epochs trained: 104
 Training neural network. Epochs trained: 105
 Training neural network. Epochs trained: 106
 Training neural network. Epochs trained: 107
 Training neural network. Epochs trained: 108
 Training neural network. Epochs trained: 109
 Training neural network. Epochs trained: 110
 Training neural network. Epochs trained: 111
 Training neural network. Epochs trained: 112
 Training neural network. Epochs trained: 113
 Training neural network. Epochs trained: 114
 Training neural network. Epochs trained: 115
 Training neural network. Epochs trained: 116
 Training neural network. Epochs trained: 117
 Training neural network. Epochs trained: 118
 Training neural network. Epochs trained: 119
 Training neural network. Epochs trained: 120
 Training neural network. Epochs trained: 121
 Training neural network. Epochs trained: 122
 Training neural network. Epochs trained: 123
 Training neural network. Epochs trained: 124
 Training neural network. Epochs trained: 125
 Training neural network. Epochs trained: 126
 Training neural network. Epochs trained: 127
 Training neural network. Epochs trained: 128
 Training neural network. Epochs trained: 129
 Training neural network. Epochs trained: 130
 Training neural network. Epochs trained: 131
 Training neural network. Epochs trained: 132
 Training neural network. Epochs trained: 133
 Training neural network. Epochs trained: 134
 Training neural network. Epochs trained: 135
 Training neural network. Epochs trained: 136
 Training neural network. Epochs trained: 137
 Training neural network. Epochs trained: 138
 Training neural network. Epochs trained: 139
 Training neural network. Epochs trained: 140
 Training neural network. Epochs trained: 141
 Training neural network. Epochs trained: 142
 Training neural network. Epochs trained: 143
 Training neural network. Epochs trained: 144
 Training neural network. Epochs trained: 145
 Training neural network. Epochs trained: 146
 Training neural network. Epochs trained: 147
 Training neural network. Epochs trained: 148
 Training neural network. Epochs trained: 149
 Training neural network. Epochs trained: 150
 Training neural network. Epochs trained: 151
 Training neural network. Epochs trained: 152
 Training neural network. Epochs trained: 153
 Training neural network. Epochs trained: 154
 Training neural network. Epochs trained: 155
 Training neural network. Epochs trained: 156
 Training neural network. Epochs trained: 157
 Training neural network. Epochs trained: 158
 Training neural network. Epochs trained: 159
 Training neural network. Epochs trained: 160
 Training neural network. Epochs trained: 161
 Training neural network. Epochs trained: 162
 Training neural network. Epochs trained: 163
 Training neural network. Epochs trained: 164
 Training neural network. Epochs trained: 165
 Training neural network. Epochs trained: 166
 Training neural network. Epochs trained: 167
 Training neural network. Epochs trained: 168
 Training neural network. Epochs trained: 169
 Training neural network. Epochs trained: 170
 Training neural network. Epochs trained: 171
 Training neural network. Epochs trained: 172
 Training neural network. Epochs trained: 173
 Training neural network. Epochs trained: 174
 Training neural network. Epochs trained: 175
 Training neural network. Epochs trained: 176
 Training neural network. Epochs trained: 177
 Training neural network. Epochs trained: 178
 Training neural network. Epochs trained: 179
 Training neural network. Epochs trained: 180
 Training neural network. Epochs trained: 181
 Training neural network. Epochs trained: 182
 Training neural network. Epochs trained: 183
 Training neural network. Epochs trained: 184
 Training neural network. Epochs trained: 185
 Training neural network. Epochs trained: 186
 Training neural network. Epochs trained: 187
 Training neural network. Epochs trained: 188
 Training neural network. Epochs trained: 189
 Training neural network. Epochs trained: 190
 Training neural network. Epochs trained: 191
 Training neural network. Epochs trained: 192
 Neural network successfully converged after 192 epochs.
# sample from SNL posterior
# Draw 2000 samples conditioned on the observed summary statistics (the
# captured output below shows this runs an MCMC chain internally).
samples = posterior.sample((2000,), x=observation_summary_stats)
snl_data = np.array(samples)
np.savetxt("wide_prior_samples.txt", snl_data)
Running 1 MCMC chains in 1 batches.:   0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]


Tuning bracket width...:   0%|          | 0/50 [00:00<?, ?it/s]


Tuning bracket width...:  20%|██        | 10/50 [00:02<00:09,  4.21it/s]


Tuning bracket width...:  40%|████      | 20/50 [00:03<00:04,  6.67it/s]


Tuning bracket width...:  60%|██████    | 30/50 [00:04<00:02,  8.51it/s]


Tuning bracket width...:  80%|████████  | 40/50 [00:04<00:01,  9.56it/s]


Tuning bracket width...: 100%|██████████| 50/50 [00:05<00:00, 10.48it/s]

Tuning bracket width...: 100%|██████████| 50/50 [00:05<00:00,  8.78it/s]


  0%|          | 0/10 [00:00<?, ?it/s]


Generating samples:   0%|          | 0/10 [00:00<?, ?it/s]


Generating samples: 100%|██████████| 10/10 [00:08<00:00,  1.23it/s]

Generating samples: 100%|██████████| 10/10 [00:08<00:00,  1.23it/s]


  0%|          | 0/2000 [00:00<?, ?it/s]


Generating samples:   0%|          | 0/2000 [00:00<?, ?it/s]


Generating samples:   0%|          | 10/2000 [00:08<28:20,  1.17it/s]


Generating samples:   1%|          | 20/2000 [00:16<27:31,  1.20it/s]


Generating samples:   2%|▏         | 30/2000 [00:25<27:42,  1.19it/s]


Generating samples:   2%|▏         | 40/2000 [00:33<27:08,  1.20it/s]


Generating samples:   2%|▎         | 50/2000 [00:42<27:27,  1.18it/s]


Generating samples:   3%|▎         | 60/2000 [00:50<27:25,  1.18it/s]


Generating samples:   4%|▎         | 70/2000 [00:59<27:09,  1.18it/s]


Generating samples:   4%|▍         | 80/2000 [01:07<27:09,  1.18it/s]


Generating samples:   4%|▍         | 90/2000 [01:16<27:18,  1.17it/s]


Generating samples:   5%|▌         | 100/2000 [01:24<26:57,  1.17it/s]


Generating samples:   6%|▌         | 110/2000 [01:33<26:38,  1.18it/s]


Generating samples:   6%|▌         | 120/2000 [01:41<26:26,  1.18it/s]


Generating samples:   6%|▋         | 130/2000 [01:49<26:10,  1.19it/s]


Generating samples:   7%|▋         | 140/2000 [01:58<26:06,  1.19it/s]


Generating samples:   8%|▊         | 150/2000 [02:06<25:46,  1.20it/s]


Generating samples:   8%|▊         | 160/2000 [02:14<25:31,  1.20it/s]


Generating samples:   8%|▊         | 170/2000 [02:22<25:15,  1.21it/s]


Generating samples:   9%|▉         | 180/2000 [02:31<25:19,  1.20it/s]


Generating samples:  10%|▉         | 190/2000 [02:39<25:11,  1.20it/s]


Generating samples:  10%|█         | 200/2000 [02:47<24:48,  1.21it/s]


Generating samples:  10%|█         | 210/2000 [02:56<24:38,  1.21it/s]


Generating samples:  11%|█         | 220/2000 [03:04<24:33,  1.21it/s]


Generating samples:  12%|█▏        | 230/2000 [03:13<25:00,  1.18it/s]


Generating samples:  12%|█▏        | 240/2000 [03:21<24:38,  1.19it/s]


Generating samples:  12%|█▎        | 250/2000 [03:29<24:19,  1.20it/s]


Generating samples:  13%|█▎        | 260/2000 [03:38<24:23,  1.19it/s]


Generating samples:  14%|█▎        | 270/2000 [03:46<24:13,  1.19it/s]


Generating samples:  14%|█▍        | 280/2000 [03:54<23:55,  1.20it/s]


Generating samples:  14%|█▍        | 290/2000 [04:03<23:40,  1.20it/s]


Generating samples:  15%|█▌        | 300/2000 [04:11<23:17,  1.22it/s]


Generating samples:  16%|█▌        | 310/2000 [04:19<23:12,  1.21it/s]


Generating samples:  16%|█▌        | 320/2000 [04:27<23:19,  1.20it/s]


Generating samples:  16%|█▋        | 330/2000 [04:36<23:17,  1.19it/s]


Generating samples:  17%|█▋        | 340/2000 [04:44<23:01,  1.20it/s]


Generating samples:  18%|█▊        | 350/2000 [04:52<22:45,  1.21it/s]


Generating samples:  18%|█▊        | 360/2000 [05:01<22:42,  1.20it/s]


Generating samples:  18%|█▊        | 370/2000 [05:09<22:29,  1.21it/s]


Generating samples:  19%|█▉        | 380/2000 [05:17<22:13,  1.22it/s]


Generating samples:  20%|█▉        | 390/2000 [05:25<22:01,  1.22it/s]


Generating samples:  20%|██        | 400/2000 [05:34<22:04,  1.21it/s]


Generating samples:  20%|██        | 410/2000 [05:42<21:46,  1.22it/s]


Generating samples:  21%|██        | 420/2000 [05:50<21:38,  1.22it/s]


Generating samples:  22%|██▏       | 430/2000 [05:58<21:26,  1.22it/s]


Generating samples:  22%|██▏       | 440/2000 [06:06<21:22,  1.22it/s]


Generating samples:  22%|██▎       | 450/2000 [06:15<21:19,  1.21it/s]


Generating samples:  23%|██▎       | 460/2000 [06:23<20:56,  1.23it/s]


Generating samples:  24%|██▎       | 470/2000 [06:31<20:43,  1.23it/s]


Generating samples:  24%|██▍       | 480/2000 [06:39<20:42,  1.22it/s]


Generating samples:  24%|██▍       | 490/2000 [06:47<20:36,  1.22it/s]


Generating samples:  25%|██▌       | 500/2000 [06:55<20:27,  1.22it/s]


Generating samples:  26%|██▌       | 510/2000 [07:04<20:19,  1.22it/s]


Generating samples:  26%|██▌       | 520/2000 [07:12<20:11,  1.22it/s]


Generating samples:  26%|██▋       | 530/2000 [07:20<20:07,  1.22it/s]


Generating samples:  27%|██▋       | 540/2000 [07:28<19:57,  1.22it/s]


Generating samples:  28%|██▊       | 550/2000 [07:36<19:54,  1.21it/s]


Generating samples:  28%|██▊       | 560/2000 [07:45<19:48,  1.21it/s]


Generating samples:  28%|██▊       | 570/2000 [07:53<19:35,  1.22it/s]


Generating samples:  29%|██▉       | 580/2000 [08:01<19:18,  1.23it/s]


Generating samples:  30%|██▉       | 590/2000 [08:09<19:11,  1.22it/s]


Generating samples:  30%|███       | 600/2000 [08:17<18:53,  1.23it/s]


Generating samples:  30%|███       | 610/2000 [08:25<18:54,  1.23it/s]


Generating samples:  31%|███       | 620/2000 [08:33<18:35,  1.24it/s]


Generating samples:  32%|███▏      | 630/2000 [08:42<18:34,  1.23it/s]


Generating samples:  32%|███▏      | 640/2000 [08:50<18:27,  1.23it/s]


Generating samples:  32%|███▎      | 650/2000 [08:58<18:26,  1.22it/s]


Generating samples:  33%|███▎      | 660/2000 [09:06<18:17,  1.22it/s]


Generating samples:  34%|███▎      | 670/2000 [09:15<18:21,  1.21it/s]


Generating samples:  34%|███▍      | 680/2000 [09:23<18:13,  1.21it/s]


Generating samples:  34%|███▍      | 690/2000 [09:31<18:12,  1.20it/s]


Generating samples:  35%|███▌      | 700/2000 [09:40<17:55,  1.21it/s]


Generating samples:  36%|███▌      | 710/2000 [09:48<17:46,  1.21it/s]


Generating samples:  36%|███▌      | 720/2000 [09:56<17:36,  1.21it/s]


Generating samples:  36%|███▋      | 730/2000 [10:04<17:23,  1.22it/s]


Generating samples:  37%|███▋      | 740/2000 [10:13<17:27,  1.20it/s]


Generating samples:  38%|███▊      | 750/2000 [10:21<17:17,  1.21it/s]


Generating samples:  38%|███▊      | 760/2000 [10:29<17:12,  1.20it/s]


Generating samples:  38%|███▊      | 770/2000 [10:38<17:04,  1.20it/s]


Generating samples:  39%|███▉      | 780/2000 [10:46<16:50,  1.21it/s]


Generating samples:  40%|███▉      | 790/2000 [10:54<16:44,  1.20it/s]


Generating samples:  40%|████      | 800/2000 [11:02<16:30,  1.21it/s]


Generating samples:  40%|████      | 810/2000 [11:11<16:28,  1.20it/s]


Generating samples:  41%|████      | 820/2000 [11:19<16:08,  1.22it/s]


Generating samples:  42%|████▏     | 830/2000 [11:27<16:02,  1.22it/s]


Generating samples:  42%|████▏     | 840/2000 [11:35<15:50,  1.22it/s]


Generating samples:  42%|████▎     | 850/2000 [11:43<15:40,  1.22it/s]


Generating samples:  43%|████▎     | 860/2000 [11:52<15:35,  1.22it/s]


Generating samples:  44%|████▎     | 870/2000 [12:00<15:30,  1.21it/s]


Generating samples:  44%|████▍     | 880/2000 [12:08<15:20,  1.22it/s]


Generating samples:  44%|████▍     | 890/2000 [12:16<15:03,  1.23it/s]


Generating samples:  45%|████▌     | 900/2000 [12:24<14:52,  1.23it/s]


Generating samples:  46%|████▌     | 910/2000 [12:32<14:48,  1.23it/s]


Generating samples:  46%|████▌     | 920/2000 [12:40<14:36,  1.23it/s]


Generating samples:  46%|████▋     | 930/2000 [12:48<14:25,  1.24it/s]


Generating samples:  47%|████▋     | 940/2000 [12:57<14:21,  1.23it/s]


Generating samples:  48%|████▊     | 950/2000 [13:05<14:26,  1.21it/s]


Generating samples:  48%|████▊     | 960/2000 [13:14<14:26,  1.20it/s]


Generating samples:  48%|████▊     | 970/2000 [13:22<14:16,  1.20it/s]


Generating samples:  49%|████▉     | 980/2000 [13:30<14:07,  1.20it/s]


Generating samples:  50%|████▉     | 990/2000 [13:39<14:02,  1.20it/s]


Generating samples:  50%|█████     | 1000/2000 [13:47<13:47,  1.21it/s]


Generating samples:  50%|█████     | 1010/2000 [13:55<13:33,  1.22it/s]


Generating samples:  51%|█████     | 1020/2000 [14:03<13:24,  1.22it/s]


Generating samples:  52%|█████▏    | 1030/2000 [14:11<13:12,  1.22it/s]


Generating samples:  52%|█████▏    | 1040/2000 [14:19<13:01,  1.23it/s]


Generating samples:  52%|█████▎    | 1050/2000 [14:27<12:52,  1.23it/s]


Generating samples:  53%|█████▎    | 1060/2000 [14:35<12:46,  1.23it/s]


Generating samples:  54%|█████▎    | 1070/2000 [14:44<12:44,  1.22it/s]


Generating samples:  54%|█████▍    | 1080/2000 [14:52<12:38,  1.21it/s]


Generating samples:  55%|█████▍    | 1090/2000 [15:00<12:28,  1.22it/s]


Generating samples:  55%|█████▌    | 1100/2000 [15:09<12:22,  1.21it/s]


Generating samples:  56%|█████▌    | 1110/2000 [15:17<12:15,  1.21it/s]


Generating samples:  56%|█████▌    | 1120/2000 [15:25<12:06,  1.21it/s]


Generating samples:  56%|█████▋    | 1130/2000 [15:33<11:50,  1.22it/s]


Generating samples:  57%|█████▋    | 1140/2000 [15:41<11:35,  1.24it/s]


Generating samples:  57%|█████▊    | 1150/2000 [15:49<11:27,  1.24it/s]


Generating samples:  58%|█████▊    | 1160/2000 [15:58<11:26,  1.22it/s]


Generating samples:  58%|█████▊    | 1170/2000 [16:06<11:17,  1.23it/s]


Generating samples:  59%|█████▉    | 1180/2000 [16:14<11:06,  1.23it/s]


Generating samples:  60%|█████▉    | 1190/2000 [16:22<10:59,  1.23it/s]


Generating samples:  60%|██████    | 1200/2000 [16:30<10:51,  1.23it/s]


Generating samples:  60%|██████    | 1210/2000 [16:38<10:44,  1.23it/s]


Generating samples:  61%|██████    | 1220/2000 [16:46<10:38,  1.22it/s]


Generating samples:  62%|██████▏   | 1230/2000 [16:55<10:31,  1.22it/s]


Generating samples:  62%|██████▏   | 1240/2000 [17:03<10:23,  1.22it/s]


Generating samples:  62%|██████▎   | 1250/2000 [17:11<10:20,  1.21it/s]


Generating samples:  63%|██████▎   | 1260/2000 [17:19<10:08,  1.22it/s]


Generating samples:  64%|██████▎   | 1270/2000 [17:28<10:01,  1.21it/s]


Generating samples:  64%|██████▍   | 1280/2000 [17:36<09:57,  1.21it/s]


Generating samples:  64%|██████▍   | 1290/2000 [17:44<09:49,  1.20it/s]


Generating samples:  65%|██████▌   | 1300/2000 [17:53<09:37,  1.21it/s]


Generating samples:  66%|██████▌   | 1310/2000 [18:01<09:31,  1.21it/s]


Generating samples:  66%|██████▌   | 1320/2000 [18:09<09:25,  1.20it/s]


Generating samples:  66%|██████▋   | 1330/2000 [18:18<09:15,  1.21it/s]


Generating samples:  67%|██████▋   | 1340/2000 [18:26<09:05,  1.21it/s]


Generating samples:  68%|██████▊   | 1350/2000 [18:34<08:59,  1.21it/s]


Generating samples:  68%|██████▊   | 1360/2000 [18:42<08:50,  1.21it/s]


Generating samples:  68%|██████▊   | 1370/2000 [18:51<08:42,  1.21it/s]


Generating samples:  69%|██████▉   | 1380/2000 [18:59<08:40,  1.19it/s]


Generating samples:  70%|██████▉   | 1390/2000 [19:08<08:31,  1.19it/s]


Generating samples:  70%|███████   | 1400/2000 [19:16<08:17,  1.21it/s]


Generating samples:  70%|███████   | 1410/2000 [19:24<08:06,  1.21it/s]


Generating samples:  71%|███████   | 1420/2000 [19:32<07:57,  1.21it/s]


Generating samples:  72%|███████▏  | 1430/2000 [19:40<07:48,  1.22it/s]


Generating samples:  72%|███████▏  | 1440/2000 [19:48<07:36,  1.23it/s]


Generating samples:  72%|███████▎  | 1450/2000 [19:56<07:26,  1.23it/s]


Generating samples:  73%|███████▎  | 1460/2000 [20:05<07:25,  1.21it/s]


Generating samples:  74%|███████▎  | 1470/2000 [20:13<07:20,  1.20it/s]


Generating samples:  74%|███████▍  | 1480/2000 [20:22<07:13,  1.20it/s]


Generating samples:  74%|███████▍  | 1490/2000 [20:30<06:59,  1.22it/s]


Generating samples:  75%|███████▌  | 1500/2000 [20:38<06:49,  1.22it/s]


Generating samples:  76%|███████▌  | 1510/2000 [20:46<06:46,  1.21it/s]


Generating samples:  76%|███████▌  | 1520/2000 [20:54<06:34,  1.22it/s]


Generating samples:  76%|███████▋  | 1530/2000 [21:03<06:24,  1.22it/s]


Generating samples:  77%|███████▋  | 1540/2000 [21:11<06:18,  1.22it/s]


Generating samples:  78%|███████▊  | 1550/2000 [21:19<06:11,  1.21it/s]


Generating samples:  78%|███████▊  | 1560/2000 [21:27<06:02,  1.21it/s]


Generating samples:  78%|███████▊  | 1570/2000 [21:36<05:57,  1.20it/s]


Generating samples:  79%|███████▉  | 1580/2000 [21:44<05:47,  1.21it/s]


Generating samples:  80%|███████▉  | 1590/2000 [21:52<05:39,  1.21it/s]


Generating samples:  80%|████████  | 1600/2000 [22:01<05:30,  1.21it/s]


Generating samples:  80%|████████  | 1610/2000 [22:09<05:24,  1.20it/s]


Generating samples:  81%|████████  | 1620/2000 [22:17<05:14,  1.21it/s]


Generating samples:  82%|████████▏ | 1630/2000 [22:26<05:06,  1.21it/s]


Generating samples:  82%|████████▏ | 1640/2000 [22:34<04:59,  1.20it/s]


Generating samples:  82%|████████▎ | 1650/2000 [22:42<04:51,  1.20it/s]


Generating samples:  83%|████████▎ | 1660/2000 [22:51<04:43,  1.20it/s]


Generating samples:  84%|████████▎ | 1670/2000 [22:59<04:36,  1.19it/s]


Generating samples:  84%|████████▍ | 1680/2000 [23:07<04:27,  1.19it/s]


Generating samples:  84%|████████▍ | 1690/2000 [23:16<04:19,  1.19it/s]


Generating samples:  85%|████████▌ | 1700/2000 [23:24<04:09,  1.20it/s]


Generating samples:  86%|████████▌ | 1710/2000 [23:32<03:57,  1.22it/s]


Generating samples:  86%|████████▌ | 1720/2000 [23:40<03:51,  1.21it/s]


Generating samples:  86%|████████▋ | 1730/2000 [23:49<03:42,  1.21it/s]


Generating samples:  87%|████████▋ | 1740/2000 [23:57<03:34,  1.21it/s]


Generating samples:  88%|████████▊ | 1750/2000 [24:05<03:22,  1.23it/s]


Generating samples:  88%|████████▊ | 1760/2000 [24:13<03:15,  1.23it/s]


Generating samples:  88%|████████▊ | 1770/2000 [24:21<03:08,  1.22it/s]


Generating samples:  89%|████████▉ | 1780/2000 [24:30<03:01,  1.21it/s]


Generating samples:  90%|████████▉ | 1790/2000 [24:38<02:52,  1.21it/s]


Generating samples:  90%|█████████ | 1800/2000 [24:46<02:44,  1.21it/s]


Generating samples:  90%|█████████ | 1810/2000 [24:54<02:34,  1.23it/s]


Generating samples:  91%|█████████ | 1820/2000 [25:02<02:25,  1.24it/s]


Generating samples:  92%|█████████▏| 1830/2000 [25:10<02:19,  1.22it/s]


Generating samples:  92%|█████████▏| 1840/2000 [25:19<02:12,  1.21it/s]


Generating samples:  92%|█████████▎| 1850/2000 [25:27<02:04,  1.21it/s]


Generating samples:  93%|█████████▎| 1860/2000 [25:35<01:56,  1.21it/s]


Generating samples:  94%|█████████▎| 1870/2000 [25:44<01:48,  1.20it/s]


Generating samples:  94%|█████████▍| 1880/2000 [25:52<01:40,  1.19it/s]


Generating samples:  94%|█████████▍| 1890/2000 [26:01<01:32,  1.19it/s]


Generating samples:  95%|█████████▌| 1900/2000 [26:09<01:24,  1.19it/s]


Generating samples:  96%|█████████▌| 1910/2000 [26:18<01:15,  1.19it/s]


Generating samples:  96%|█████████▌| 1920/2000 [26:26<01:06,  1.20it/s]


Generating samples:  96%|█████████▋| 1930/2000 [26:34<00:58,  1.21it/s]


Generating samples:  97%|█████████▋| 1940/2000 [26:43<00:50,  1.19it/s]


Generating samples:  98%|█████████▊| 1950/2000 [26:51<00:42,  1.19it/s]


Generating samples:  98%|█████████▊| 1960/2000 [26:59<00:33,  1.20it/s]


Generating samples:  98%|█████████▊| 1970/2000 [27:08<00:24,  1.20it/s]


Generating samples:  99%|█████████▉| 1980/2000 [27:16<00:16,  1.20it/s]


Generating samples: 100%|█████████▉| 1990/2000 [27:24<00:08,  1.20it/s]


Generating samples: 100%|██████████| 2000/2000 [27:33<00:00,  1.19it/s]

Generating samples: 100%|██████████| 2000/2000 [27:33<00:00,  1.21it/s]
Running 1 MCMC chains in 1 batches.: 100%|██████████| 1/1 [27:47<00:00, 1667.10s/it]
Running 1 MCMC chains in 1 batches.: 100%|██████████| 1/1 [27:47<00:00, 1667.11s/it]

# set up restraints for maxent
# restraint structure: [value, uncertainty, indices... ]
restraints = []
for obs_num, point in enumerate(observed_points):
    # trajectory row that produced this observation
    # (matches the slicing done in get_observation_points())
    row = 20 * obs_num + 19
    # one restraint per coordinate: x -> column 0, y -> column 1
    restraints.append([point[0], 25, row, 0])
    restraints.append([point[1], 25, row, 1])

# set up maxent restraints: each restraint pins traj[row, col] to the
# observed value through maxent's Restraint objects
maxent_restraints = []
for entry in restraints:
    target_value = entry[0]
    coord = tuple(entry[2:])
    # bind coord as a default argument so every lambda keeps its own index
    # (avoids the late-binding closure pitfall)
    maxent_restraints.append(
        maxent.Restraint(lambda traj, i=coord: traj[i], target_value, maxent.EmptyPrior())
    )
# sample from prior for maxent
# Cached on disk so the same 2048-sample prior ensemble is reused across runs.
if os.path.exists("maxent_prior_samples.npy"):
    prior_dist = np.load("maxent_prior_samples.npy")
else:
    prior_dist = np.random.multivariate_normal(prior_means, prior_cov, size=2048)
    np.save("maxent_prior_samples.npy", prior_dist)
# generate trajectories for maxent from prior samples
# one (100, 2) trajectory per prior sample
trajs = np.zeros([prior_dist.shape[0], 100, 2])

for i, sample in enumerate(prior_dist):
    # fresh local names so the true-parameter globals (m1, m2, m3, v0) are
    # not silently clobbered by the loop (they previously ended up holding
    # the last prior sample after this block ran)
    s_m1, s_m2, s_m3, s_v0 = sample[0], sample[1], sample[2], sample[3:]
    # NOTE(review): G is not passed here, unlike the earlier simulator calls;
    # assumes GravitySimulator's default gravitational constant matches G — confirm
    sim = GravitySimulator(s_m1, s_m2, s_m3, s_v0, random_noise=False)
    trajs[i] = sim.run()

maxent_trajs = trajs
np.save("maxent_raw_trajectories.npy", trajs)
# run maxent on trajectories
# batch_size equals the number of trajectories, i.e. full-batch updates
batch_size = prior_dist.shape[0]

model = maxent.MaxentModel(maxent_restraints)
model.compile(tf.keras.optimizers.Adam(1e-4), "mean_squared_error")
# short burn-in
h = model.fit(trajs, batch_size=batch_size, epochs=5000, verbose=0)
# restart to reset learning rate
h = model.fit(trajs, batch_size=batch_size, epochs=25000, verbose=0)

# loss history from the second (long) fit only — the burn-in history is discarded
np.savetxt("maxent_loss.txt", h.history["loss"])

# per-trajectory weights learned by maxent
maxent_weights = model.traj_weights
np.savetxt("maxent_traj_weights.txt", maxent_weights)

# weight-averaged trajectory under the maxent reweighting
maxent_avg_traj = np.sum(trajs * maxent_weights[:, np.newaxis, np.newaxis], axis=0)
np.savetxt("maxent_avg_traj.txt", maxent_avg_traj)

Plotting Results

# simulate traj generated by prior means
sim = GravitySimulator(prior_means[0], prior_means[1], prior_means[2], prior_means[3:])
prior_means_traj = sim.run()

# simulate trajectories from SNL samples
snl_trajs = np.zeros([snl_data.shape[0], noisy_traj.shape[0], noisy_traj.shape[1]])
for i, sample in enumerate(snl_data):
    # fresh local names so the true-parameter globals (m1, m2, m3, v0) are
    # not silently overwritten by the loop
    s_m1, s_m2, s_m3, s_v0 = sample[0], sample[1], sample[2], [sample[3], sample[4]]
    # NOTE(review): G is not passed here, unlike the earlier simulator calls;
    # assumes GravitySimulator's default matches G above — confirm
    sim = GravitySimulator(s_m1, s_m2, s_m3, s_v0)
    snl_trajs[i] = sim.run()

mean_snl_traj = np.mean(snl_trajs, axis=0)
np.savetxt("mean_snl_traj.txt", mean_snl_traj)
alpha_val = 0.7
fig, axes = plt.subplots(figsize=(5, 3), dpi=300)

# plot the observation points
axes.scatter(
    observed_points[:, 0],
    observed_points[:, 1],
    color="black",
    zorder=10,
    marker="*",
    label="Observed Points",
)

# options shared by every trajectory overlay below
shared_kwargs = dict(
    fig=fig,
    axes=axes,
    make_colorbar=False,
    save=False,
    fade_lines=False,
    alpha=alpha_val,
    linewidth=1,
)

# (trajectory, per-curve options), drawn in order:
# prior mean, SNL mean, true path, maxent average
overlays = [
    (
        prior_means_traj,
        dict(
            cmap=plt.get_cmap("Greys").reversed(),
            color="grey",
            linestyle="-.",
            label="Prior Mean",
        ),
    ),
    (
        mean_snl_traj,
        dict(cmap=plt.get_cmap("Greens").reversed(), color=colors[0], label="SNL"),
    ),
    (
        true_traj,
        dict(
            cmap=plt.get_cmap("Reds").reversed(),
            color="black",
            linestyle=":",
            label="True Path",
            label_attractors=False,
        ),
    ),
    (
        maxent_avg_traj,
        dict(
            cmap=plt.get_cmap("Oranges").reversed(),
            color=colors[2],
            linestyle="-",
            label="MaxEnt",
            label_attractors=True,
        ),
    ),
]
for overlay_traj, extra_kwargs in overlays:
    sim.set_traj(overlay_traj)
    sim.plot_traj(**shared_kwargs, **extra_kwargs)

# set limits manually
axes.set_xlim(-5, 130)
axes.set_ylim(-30, 75)

plt.legend(loc="upper left", bbox_to_anchor=(1.05, 1.0))
plt.tight_layout()

# plt.savefig('paths_compare.png')
# plt.savefig('paths_compare.svg')
plt.show()
../_images/gravitation_14_0.png
# set up KDE plotting of posteriors
column_names = ["m1", "m2", "m3", "v0x", "v0y"]

snl_dist = np.array(snl_data)
snl_frame = pd.DataFrame(snl_dist, columns=column_names)

maxent_dist = np.load("maxent_prior_samples.npy")
maxent_frame = pd.DataFrame(maxent_dist, columns=column_names)

fig, axes = plt.subplots(nrows=5, ncols=1, figsize=(5, 5), dpi=300, sharex=False)

n_bins = 30


def _marginal_hist(frame, ax, key, color, weights=None):
    """Step-outline histogram + KDE of one parameter's marginal (optionally weighted)."""
    sns.histplot(
        data=frame,
        x=key,
        ax=ax,
        color=color,
        stat="probability",
        element="step",
        kde=True,
        fill=False,
        bins=n_bins,
        weights=weights,
        lw=1.0,
    )


# one row per parameter: SNL posterior, maxent-reweighted prior, raw prior
for i, key in enumerate(column_names):
    _marginal_hist(snl_frame, axes[i], key, colors[0])
    _marginal_hist(maxent_frame, axes[i], key, colors[2], weights=maxent_weights)
    _marginal_hist(maxent_frame, axes[i], key, colors[3])
    # reference lines at the prior mean and the true parameter value
    axes[i].axvline(prior_means[i], ls="-.", color="grey", lw=1.2)
    axes[i].axvline(true_params[i], ls=":", color="black", lw=1.2)
    axes[i].set_xlabel(key)

# custom lines object for making legend
custom_lines = [
    Line2D([0], [0], color=colors[3], lw=2),
    Line2D([0], [0], color=colors[0], lw=2),
    Line2D([0], [0], color=colors[2], lw=2),
    Line2D([0], [0], color="black", ls=":", lw=2),
    Line2D([0], [0], color="grey", ls="-.", lw=2),
]
axes[0].legend(
    custom_lines,
    ["Prior", "SNL", "MaxEnt", "True Parameters", "Prior Mean"],
    loc="upper left",
    bbox_to_anchor=(1.05, 1.0),
)
plt.tight_layout()

# plt.savefig('posterior_compare.png')
# plt.savefig('posterior_compare.svg')
plt.show()
../_images/gravitation_15_0.png
# calculating cross-entropy values
def get_crossent(
    prior_samples,
    posterior_samples,
    epsilon=1e-7,
    x_range=(-100, 100),
    nbins=40,
    post_weights=None,
):
    """Histogram-based cross-entropy between prior and posterior marginals.

    For each parameter (column), a 1-D density histogram is built for the
    prior samples (p) and the optionally weighted posterior samples (q), and
    the per-bin terms (p + epsilon) * log(q + epsilon) are accumulated.
    Returns the negated total, summed over all bins and all parameters.

    Parameters
    ----------
    prior_samples : (n, d) array of prior draws
    posterior_samples : (m, d) array of posterior draws
    epsilon : smoothing constant keeping log() finite on empty bins
    x_range : (low, high) histogram range shared by every parameter
        (tuple default now — the old list default was a mutable default arg)
    nbins : number of histogram bins per parameter
    post_weights : optional per-sample weights for the posterior histogram
    """
    crossents = []
    # iterate over every parameter column (generalized from a hard-coded 5)
    for i in range(prior_samples.shape[1]):
        prior_dist, _ = np.histogram(
            prior_samples[:, i], bins=nbins, range=x_range, density=True
        )
        posterior_dist, _ = np.histogram(
            posterior_samples[:, i],
            bins=nbins,
            range=x_range,
            density=True,
            weights=post_weights,
        )
        crossents.append(np.log(posterior_dist + epsilon) * (prior_dist + epsilon))
    return -np.sum(crossents)


# Score each method's posterior against fresh draws from the same prior.
prior_covariance = np.eye(5) * 50
snl_prior = np.random.multivariate_normal(
    mean=prior_means, cov=prior_covariance, size=snl_dist.shape[0]
)
snl_crossent = get_crossent(snl_prior, snl_dist)

# The MaxEnt "posterior" is the prior sample set reweighted by the
# MaxEnt trajectory weights, hence the same array appears twice.
maxent_prior = np.random.multivariate_normal(prior_means, prior_covariance, size=2048)
maxent_crossent = get_crossent(maxent_prior, maxent_prior, post_weights=maxent_weights)

print(f"CROSS-ENTROPY:\nSNL: {snl_crossent}\nMaxEnt: {maxent_crossent}")

crossent_values = [snl_crossent, maxent_crossent]
np.savetxt("crossent_values.txt", np.array(crossent_values), header="SNL, MaxEnt")
CROSS-ENTROPY:
SNL: 5.801560245778351
MaxEnt: 3.410306612211575

MaxEnt With Variational

# --- Variational MaxEnt: trainable parameter distribution ---
# Build a joint distribution over the 5 simulation parameters: a
# multivariate normal whose location starts at the prior means (and is
# trainable) with fixed covariance equal to the prior covariance.
import tensorflow_probability as tfp

tfd = tfp.distributions
# initial (trainable) location and fixed covariance, cast to float32 for TF
x = np.array(prior_means, dtype=np.float32)
y = np.array(prior_cov, dtype=np.float32)
# input placeholder shaped like one trajectory: 100 steps of (x, y)
i = tf.keras.Input((100, 2))
l = maxent.TrainableInputLayer(x)(i)
# NOTE(review): MultivariateNormalFullCovariance is deprecated in TFP;
# MultivariateNormalTriL(loc, tf.linalg.cholesky(cov)) is the replacement.
# The lambda's `x` shadows the module-level `x` above: it receives the
# trainable location from the previous layer, with covariance `y` closed over.
d = tfp.layers.DistributionLambda(
    lambda x: tfd.MultivariateNormalFullCovariance(loc=x, covariance_matrix=y)
)(l)
model = maxent.ParameterJoint([lambda x: x], inputs=i, outputs=[d])
model.compile(tf.keras.optimizers.SGD(1e-3))
model.summary()
# smoke-test call with a dummy 5-parameter vector
model(tf.constant([1.0, 2.0, 3.0, 4.0, 5.0]))
WARNING:tensorflow:From /opt/hostedtoolcache/Python/3.8.12/x64/lib/python3.8/site-packages/tensorflow_probability/python/distributions/distribution.py:342: MultivariateNormalFullCovariance.__init__ (from tensorflow_probability.python.distributions.mvn_full_covariance) is deprecated and will be removed after 2019-12-01.
Instructions for updating:
`MultivariateNormalFullCovariance` is deprecated, use `MultivariateNormalTriL(loc=loc, scale_tril=tf.linalg.cholesky(covariance_matrix))` instead.
Model: "parameter_joint"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 100, 2)]          0         
                                                                 
 trainable_input_layer (Trai  (None, 5)                5         
 nableInputLayer)                                                
                                                                 
 distribution_lambda (Distri  ((None, 5),              0         
 butionLambda)                (None, 5))                         
                                                                 
=================================================================
Total params: 5
Trainable params: 5
Non-trainable params: 0
_________________________________________________________________
<tfp.distributions._TensorCoercible 'tensor_coercible' batch_shape=[5] event_shape=[5] dtype=float32>
def simulate(x, nsteps=100):
    """Run the noise-free gravity simulator once per parameter sample.

    Parameters
    ----------
    x : array, shape (n_samples, ?, 5)
        x[i, 0, :] holds (m1, m2, m3, v0x, v0y) for sample i.
    nsteps : int
        Number of integration steps per trajectory.

    Returns
    -------
    ndarray, shape (n_samples, nsteps, 2)
        One simulated (x, y) trajectory per parameter sample.
    """
    trajectories = np.zeros((x.shape[0], nsteps, 2))
    for idx in range(x.shape[0]):
        params = x[idx, 0, :]
        masses = (float(params[0]), float(params[1]), float(params[2]))
        velocity = np.array([params[3], params[4]], dtype=np.float64)
        runner = GravitySimulator(
            *masses, velocity, random_noise=False, nsteps=nsteps
        )
        trajectories[idx] = runner.run()
    return trajectories
def get_observation_points_from_flat(flat_traj):
    """Un-flatten a 1D trajectory back into (n, 2) points, then extract
    the observation points from it."""
    recovered = flat_traj.reshape([-1, 2])
    return get_observation_points(recovered)
# Build one MaxEnt restraint per observed coordinate of the noisy trajectory.
r = []
true_points = get_observation_points(noisy_traj)
true_points_flat = true_points.flatten()
for i, point in enumerate(true_points_flat):
    # BUG FIX: bind the loop index as a default argument (i=i). The original
    # `lambda x: ...[i]` captured `i` by reference (late binding), so every
    # restraint would have indexed the FINAL coordinate instead of its own.
    r.append(
        maxent.Restraint(
            lambda x, i=i: get_observation_points_from_flat(x)[i],
            point,
            maxent.EmptyPrior(),
        )
    )
# BUG FIX: pass the restraint list `r` built just above (which operates on
# the flattened variational trajectories); `maxent_restraints` was the list
# from the earlier plain-MaxEnt section and left `r` entirely unused.
hme_model = maxent.HyperMaxentModel(r, model, simulate)
hme_model.compile(tf.keras.optimizers.Adam(1e-4), "mean_squared_error")
hme_results = hme_model.fit(
    epochs=30000, sample_batch_size=2048 // 4, outter_epochs=4, verbose=0
)  # one-quarter of plain maxent batch size
WARNING:tensorflow:Gradients do not exist for variables ['value:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss`argument?
WARNING:tensorflow:Gradients do not exist for variables ['value:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss`argument?
WARNING:tensorflow:Model was constructed with shape (None, 100, 2) for input KerasTensor(type_spec=TensorSpec(shape=(None, 100, 2), dtype=tf.float32, name='input_1'), name='input_1', description="created by layer 'input_1'"), but it was called on an input with incompatible shape (32,).
WARNING:tensorflow:Model was constructed with shape (None, 100, 2) for input KerasTensor(type_spec=TensorSpec(shape=(None, 100, 2), dtype=tf.float32, name='input_1'), name='input_1', description="created by layer 'input_1'"), but it was called on an input with incompatible shape (32,).
WARNING:tensorflow:Gradients do not exist for variables ['value:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss`argument?
# Form the MaxEnt-weighted average of the variational trajectories and save it.
hme_predicted_params = hme_model.weights[1]
hme_trajectory_weights = hme_model.traj_weights
n_trajs = hme_model.trajs.shape[0]
variational_trajs = hme_model.trajs.reshape([n_trajs, -1, 2])
# broadcast weights over (step, coordinate) axes, then collapse the sample axis
weighted = variational_trajs * hme_trajectory_weights.reshape(-1, 1, 1)
maxent_variational_avg_traj = weighted.sum(axis=0)
np.savetxt("maxent_variational_avg_traj.txt", maxent_variational_avg_traj)
# simulate the trajectory implied by the prior means, then load the saved
# average trajectories produced by each inference method
m1_prior, m2_prior, m3_prior, *v0_prior = prior_means
sim = GravitySimulator(m1_prior, m2_prior, m3_prior, v0_prior)
prior_means_traj = sim.run()

mean_snl_traj = np.genfromtxt("mean_snl_traj.txt")
maxent_avg_traj = np.genfromtxt("maxent_avg_traj.txt")
maxent_variational_avg_traj = np.genfromtxt("maxent_variational_avg_traj.txt")
# Overlay every trajectory (prior mean, SNL, true, MaxEnt, variational
# MaxEnt) plus the observed points on a single comparison figure.
alpha_val = 0.7
fig, axes = plt.subplots(figsize=(5, 3), dpi=300)

# plot the observation points
axes.scatter(
    observed_points[:, 0],
    observed_points[:, 1],
    color="black",
    zorder=10,
    marker="*",
    label="Observed Points",
)


def _overlay_traj(traj, **kwargs):
    """Draw one trajectory on the shared axes with the common styling.

    Only per-trajectory options (cmap, color, linestyle, label,
    label_attractors, ...) are forwarded, so plot_traj's defaults apply
    exactly as in the original per-call versions.
    """
    sim.set_traj(traj)
    sim.plot_traj(
        fig=fig,
        axes=axes,
        make_colorbar=False,
        save=False,
        fade_lines=False,
        alpha=alpha_val,
        linewidth=1,
        **kwargs,
    )


# trajectory generated by the prior means
_overlay_traj(
    prior_means_traj,
    cmap=plt.get_cmap("Greys").reversed(),
    color="grey",
    linestyle="-.",
    label="Prior Mean",
)

# SNL mean trajectory (default linestyle, as before)
_overlay_traj(
    mean_snl_traj,
    cmap=plt.get_cmap("Greens").reversed(),
    color=colors[0],
    label="SNL",
)

# true trajectory
_overlay_traj(
    true_traj,
    cmap=plt.get_cmap("Reds").reversed(),
    color="black",
    linestyle=":",
    label="True Path",
    label_attractors=False,
)

# MaxEnt average trajectory
_overlay_traj(
    maxent_avg_traj,
    cmap=plt.get_cmap("Oranges").reversed(),
    color=colors[2],
    linestyle="-",
    label="MaxEnt",
    label_attractors=False,
)

# variational MaxEnt average trajectory (comment fixed: the original
# copy-pasted "maxent average trajectory" here); label_attractors=True
# so the attractor markers get labeled exactly once
_overlay_traj(
    maxent_variational_avg_traj,
    cmap=plt.get_cmap("Oranges").reversed(),
    color=colors[3],
    linestyle="-",
    label="Variational MaxEnt",
    label_attractors=True,
)

# set limits manually
axes.set_xlim(-5, 130)
axes.set_ylim(-30, 75)

plt.legend(loc="upper left", bbox_to_anchor=(1.05, 1.0))
plt.tight_layout()

plt.savefig("paths_compare.png")
plt.savefig("paths_compare.svg")
plt.show()
../_images/gravitation_27_0.png