Posted by Lubitelua •
import jax
import jax.numpy as jnp
from jax import grad, vmap, random
import optax

def mlp(params, inputs):
    # Hidden layers: affine map followed by tanh.
    for w, b in params[:-1]:
        inputs = jnp.tanh(inputs @ w + b)
    # Linear output layer, shape [..., 1].
    w, b = params[-1]
    return inputs @ w + b
def init_params(key, layers):
    # One (w, b) pair per layer; weights scaled by 1/sqrt(fan_in).
    keys = random.split(key, len(layers) - 1)
    params = []
    for kin, kout, k in zip(layers[:-1], layers[1:], keys):
        w = random.normal(k, (kin, kout)) / jnp.sqrt(kin)
        b = jnp.zeros(kout)
        params.append((w, b))
    return params
def psi_pinn(params, x, t, ctrl):
    # Network input: scalar position x, time t, and the control vector ctrl.
    inp = jnp.concatenate([jnp.array([x, t]), ctrl])
    return mlp(params, inp[None, :])[0, 0]  # scalar output, so grad() works
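A quick smoke test of the pieces so far (the layer widths and the control dimension below are made-up placeholders, not values from this thread):

key = random.PRNGKey(0)
ctrl = jnp.zeros(3)                        # assumed 3 control inputs
params = init_params(key, [5, 64, 64, 1])  # input dim = 2 coords + 3 controls
print(psi_pinn(params, 0.5, 0.1, ctrl))    # one scalar field value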
def physics_residual(params, x, t, ctrl, u):
    # Pointwise PDE residual at (x, t). u = (g2, kappa) are the learned
    # physical parameters; lam (λ in the original post) is the fixed cubic
    # nonlinearity coefficient, assumed defined in the enclosing scope.
    g2, kappa = u
    psi = psi_pinn(params, x, t, ctrl)
    dpsi_dt = grad(lambda tt: psi_pinn(params, x, tt, ctrl))(t)
    d2psi_dx2 = grad(grad(lambda xx: psi_pinn(params, xx, t, ctrl)))(x)
    res = d2psi_dx2 + dpsi_dt + psi + lam*psi**3 + g2*psi*ctrl[0] - kappa*psi
    return res**2
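Sanity-checking the residual over a grid with vmap (a sketch; lam, the grid, and the u values below are placeholders):

lam = 1.0                                # placeholder nonlinearity strength
x_grid = jnp.linspace(-5.0, 5.0, 128)    # assumed spatial grid
u0 = jnp.array([0.1, 0.05])              # assumed initial (g2, kappa)
res = vmap(lambda x: physics_residual(params, x, 0.0, ctrl, u0))(x_grid)
print(res.shape)                         # (128,) squared residuals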
Posted by Lubitelua •
def loss_fn(params, u):
    # alphas[b] is the control vector for batch b, data[:, b] the observed
    # densities at times tsave, w_t the per-time weights; alphas, data,
    # tsave, w_t, x_grid and B are captured from the enclosing scope.
    data_loss, phys_loss, topo_loss = 0.0, 0.0, 0.0
    for b in range(B):
        alpha = alphas[b]
        sim = []
        for t in tsave:
            psi_vals = vmap(lambda x: psi_pinn(params, x, t, alpha))(x_grid)
            sim.append(jnp.mean(jnp.abs(psi_vals)**2))
        sim = jnp.array(sim)
        err = sim - data[:, b]
        data_loss += jnp.mean(w_t * err**2)
        # Physics penalty: mean squared PDE residual over the (x, t) grid,
        # vmapped instead of a Python list comprehension.
        res = vmap(lambda tt: vmap(
            lambda x: physics_residual(params, x, tt, alpha, u))(x_grid))(tsave)
        phys_loss += jnp.mean(res)
        # Topology penalty: squared slope mismatch, so deviations of
        # opposite sign cannot cancel out.
        topo_loss += jnp.mean((jnp.gradient(sim) - jnp.gradient(data[:, b]))**2)
    return data_loss + 0.1*phys_loss + 0.01*topo_loss
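Since everything above is pure JAX, the loss and its gradients can be compiled once (a sketch; loss_and_grads is not a name from the original post):

@jax.jit
def loss_and_grads(params, u):
    # One compiled evaluation of the loss and its gradients w.r.t. both
    # the network parameters and the PDE parameters (g2, kappa).
    return jax.value_and_grad(loss_fn, argnums=(0, 1))(params, u)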
def train(params, u, lr, steps):
    opt = optax.adam(lr)
    state = opt.init((params, u))
    best_val, best = 1e9, (params, u)
    for i in range(steps):
        # Gradients w.r.t. both the network params and the PDE parameters u.
        val, grads = jax.value_and_grad(loss_fn, argnums=(0, 1))(params, u)
        updates, state = opt.update(grads, state)
        params, u = optax.apply_updates((params, u), updates)
        if val < best_val:
            # Track the best iterate seen so far.
            best_val, best = val, (params, u)
    return best
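End-to-end usage, assuming alphas, data, tsave, w_t, x_grid, B and lam are already defined; the layer sizes, learning rate, and step count are placeholders, not values from this thread:

key = random.PRNGKey(0)
params = init_params(key, [2 + alphas.shape[1], 64, 64, 1])  # assumed widths
u = jnp.array([0.0, 0.0])              # initial guess for (g2, kappa)
params, u = train(params, u, lr=1e-3, steps=2000)
print("best-fit (g2, kappa):", u)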