Hosted by
Alice & Bob
1
Posted by Lubitelua •
1
Posted by Lubitelua •
data_loss, phys_loss, topo_loss = 0.0, 0.0, 0.0
for b in range(B):
alpha = alphas[b]
sim = []
for t in tsave:
psi_vals = vmap(lambda x: psi_pinn(params, x, t, alpha))(x_grid)
sim.append(jnp.mean(jnp.abs(psi_vals)**2))
sim = jnp.array(sim)
err = sim - data[:, b]
data_loss += jnp.mean(w_t * err**2)
# physics + topo penalties
phys_loss += jnp.mean([physics_residual(params, x, t, alpha, u) for x in x_grid for t in tsave])
topo_loss += jnp.mean(jnp.gradient(sim) - jnp.gradient(data[:, b]))
return data_loss + 0.1*phys_loss + 0.01*topo_loss
opt = optax.adam(lr)
state = opt.init((params,u))
best_val, best = 1e9, (params,u)
for i in range(steps):
val, grads = jax.value_and_grad(loss_fn, argnums=(0,1))(params, u)
updates, state = opt.update(grads, state)
params, u = optax.apply_updates((params,u), updates)
if val < best_val:
best_val, best = val, (params,u)
return best