Hosted by Alice & Bob
1
Posted by Lubitelua •
# Presumably the body of `mlp(params, inputs)` (called that way elsewhere in this
# paste): a feed-forward network with tanh hidden layers and an affine output layer.
# NOTE(review): the `def` line and all indentation were lost when this post was
# scraped; the signature above is inferred from the call site — confirm.
# params: list of (w, b) pairs; every pair except the last is a hidden layer.
# Hidden layers: inputs <- tanh(inputs @ w + b).
for w, b in params[:-1]:
inputs = jnp.tanh(inputs @ w + b)
# Output layer is affine only (no activation).
w, b = params[-1]
return inputs @ w + b # [...,1] — trailing dim is the final layer's width (presumably 1)
# Build MLP parameters: one (w, b) pair per consecutive pair of widths in `layers`.
# NOTE(review): the `def` line and indentation were lost in the scrape; the body reads
# `key` (passed to jax.random.split) and `layers` (a sequence of layer widths), which
# are presumably its parameters — confirm the real name/signature.
# One PRNG subkey per weight matrix (len(layers) - 1 of them).
keys = random.split(key, len(layers)-1)
params = []
for kin, kout, k in zip(layers[:-1], layers[1:], keys):
# Standard-normal weights divided by sqrt(fan_in), i.e. variance 1/kin.
w = random.normal(k, (kin, kout)) / jnp.sqrt(kin)
# Biases start at zero.
b = jnp.zeros(kout)
params.append((w, b))
# List of (w, b) tuples in layer order, as consumed by the MLP forward pass.
return params
# Presumably the body of `psi_pinn(params, x, t, ctrl)` (called that way elsewhere
# in this paste): evaluates the network at a single point with input [x, t, *ctrl].
# NOTE(review): `def` line lost in the scrape. Assumes x and t are scalars and ctrl
# is a 1-D array so the concatenation yields a 1-D input vector — TODO confirm.
inp = jnp.concatenate([jnp.array([x, t]), ctrl])
# Add a leading batch axis for `mlp`, then take the single row back out.
# NOTE(review): `mlp` returns shape (1, out_width) here, so `[0]` yields shape
# (out_width,) — e.g. (1,) — not a scalar. `jax.grad` only differentiates
# scalar-output functions, so the grad calls made on this function in the loss
# would raise a TypeError; `[0, 0]` (or `.squeeze()`) is probably intended — confirm.
return mlp(params, inp[None, :])[0]
# Mean-squared PDE residual of the network at time `t` with control vector `ctrl`.
# NOTE(review): `def` line lost in the scrape; the body reads params, t, ctrl, x_grid
# and u (presumably parameters) and `λ`, which is not defined in this fragment —
# presumably a module-level constant; confirm.
# psi at every grid point, vectorised over x_grid with vmap.
psi_fn = lambda x: psi_pinn(params, x, t, ctrl)
psi = vmap(psi_fn)(x_grid)
# NOTE(review): both derivatives below are evaluated only at x_grid[0] and then
# broadcast against psi over the whole grid, so the residual at every other grid
# point uses the derivatives from the first point. They should probably be
# vmapped over x_grid like psi — confirm. Both calls also need psi_pinn to return
# a scalar, since jax.grad rejects non-scalar outputs (see psi_pinn's return).
# d(psi)/dt at x_grid[0].
dpsi_dt = grad(lambda tt: psi_pinn(params, x_grid[0], tt, ctrl))(t)
# d^2(psi)/dx^2 at x_grid[0], via nested grad.
d2psi_dx2 = grad(grad(lambda xx: psi_pinn(params, xx, t, ctrl)))(x_grid[0])
# u is unpacked into two coefficients.
g2, kappa = u
# Residual: psi_xx + psi_t + psi + λ*psi^3 + g2*ctrl[0]*psi - kappa*psi.
res = d2psi_dx2 + dpsi_dt + psi + λ*psi**3 + g2*psi*ctrl[0] - kappa*psi
# Mean of the squared residual over all elements.
return jnp.mean(res**2)
1
Posted by Lubitelua •