Dynamiqs: Experimental data fitting background cover

Dynamiqs: Experimental data fitting

Fitting a cat qubit experiment with real-world data

Alice & Bob

Hosted by

Alice & Bob

Khalamendyk Ivan
Lubitelua

1

Posted

Ψ-PINN: What if space = wavelength and time = frequency?

А що якщо ми будемо вважати, що простір — це довжина хвилі, а час — це частота? Тоді замість того, щоб ганяти класичний Lindblad master equation, ми можемо спробувати описати динаміку cat-qubit через хвильову функцію ψ(x,t), яка живе на цій «довжина–частота» сітці. 👉 Ідея така:
  • Замість щільнісної матриці ρ ми вчимо нейронну мережу (PINN), яка напряму апроксимує ψ(x,t).
  • Loss має 3 частини:
    1. Data-fit: узгодження з експериментальними траєкторіями (short.npy).
    2. Physics: штраф за порушення рівняння руху □ψ+ψ+λψ3=0,λ=32π,\square \psi + \psi + \lambda \psi^3 = 0, \quad \lambda = \tfrac{3}{2\pi},□ψ+ψ+λψ3=0,λ=2π3​, плюс дампінг та two-photon терміни.
    3. Topological: штраф за розбіжність ∇φ (градієнт фази) між ψ та даними. Таким чином, ми отримуємо гібрид квантової симуляції та топологічної хвильової моделі: PINN «вчиться» не тільки підганяти дані, але й зберігати топологічні інваріанти ψ-поля. Це принципово інший підхід: ми відмовляємось від прямої еволюції ρ(t) через mesolve і вчимо геометрію хвильового вузла.

Order by:

Khalamendyk Ivan

1

Posted by Lubitelua

--- Ψ-PINN: physics + data + topology ---

import jax, jax.numpy as jnp, optax from jax import grad, random, vmap import numpy as np

---------- Load experiment data ----------

def _load_any(name): import pathlib for p in [f".aqora/data/data/{name}", f".aqora/data/{name}", f"data/{name}", name]: if pathlib.Path(p).exists(): return np.load(p, allow_pickle=True) raise FileNotFoundError(name)
data = jnp.array(_load_any("short.npy"), dtype=jnp.float64) # [T,B] tsave = jnp.array(_load_any("time_short.npy"), dtype=jnp.float64) # [T] alphas = jnp.sqrt(data[0]) # coherent amplitudes
T, B = data.shape[0], data.shape[1]

---------- Simple MLP (PINN) ----------

def mlp(params, x): for w, b in params: x = jnp.tanh(jnp.dot(x, w) + b) return x
def init_mlp(key, layers=[2, 32, 32, 1]): keys = random.split(key, len(layers)-1) params = [] for kin, kout, k in zip(layers[:-1], layers[1:], keys): w = random.normal(k, (kin, kout)) / jnp.sqrt(kin) b = jnp.zeros(kout) params.append((w, b)) return params

---------- ψ-PINN forward ----------

def psi_pinn(params, x, t, alpha): inp = jnp.array([x, t]) return alpha * mlp(params, inp)

---------- Physics residual ----------

lambda_ = 3 / (2 * jnp.pi)
def physics_residual(params, x, t, alpha, u): psi_fn = lambda xt: psi_pinn(params, xt[0], xt[1], alpha) psi = psi_fn((x, t)) dpsi_dt = grad(psi_fn, argnums=1)((x, t)) d2psi_dx2 = grad(grad(psi_fn, argnums=0), argnums=0)((x, t)) g2, kappa_a, kappa_b, nth_a, nth_b = u res = d2psi_dx2 + dpsi_dt + psi + lambda_ * psi3 res += -kappa_a psi - kappa_b(psi - nth_bjnp.mean(psi)) - 1jg2*psi2 return jnp.real(res)**2


Khalamendyk Ivan

1

Posted by Lubitelua

---------- Loss function ----------

w_t = jnp.exp(-tsave/0.5); w_t /= w_t.sum()
x_grid = jnp.linspace(-5, 5, 32)
def loss_fn(params, u):
data_loss, phys_loss, topo_loss = 0.0, 0.0, 0.0

for b in range(B):

    alpha = alphas[b]

    sim = []

    for t in tsave:

        psi_vals = vmap(lambda x: psi_pinn(params, x, t, alpha))(x_grid)

        sim.append(jnp.mean(jnp.abs(psi_vals)**2))

    sim = jnp.array(sim)

    err = sim - data[:, b]

    data_loss += jnp.mean(w_t * err**2)

    # physics + topo penalties

    phys_loss += jnp.mean([physics_residual(params, x, t, alpha, u) for x in x_grid for t in tsave])

    topo_loss += jnp.mean(jnp.gradient(sim) - jnp.gradient(data[:, b]))

return data_loss + 0.1*phys_loss + 0.01*topo_loss

---------- Training ----------

def run_opt(params, u, steps=200, lr=1e-3):
opt = optax.adam(lr)

state = opt.init((params,u))

best_val, best = 1e9, (params,u)

for i in range(steps):

    val, grads = jax.value_and_grad(loss_fn, argnums=(0,1))(params, u)

    updates, state = opt.update(grads, state)

    params, u = optax.apply_updates((params,u), updates)

    if val < best_val:

        best_val, best = val, (params,u)

return best

---------- Example run ----------

key = random.PRNGKey(0)
params = init_mlp(key)
u0 = jnp.array([0.1, 0.01, 0.05, 1e-3, 1e-3])
best_params, best_u = run_opt(params, u0)
print("Ψ-PINN best params:", best_u)

Want to join this discussion?

Join our community today and start discussing with our members by participating in exciting events, competitions, and challenges. Sign up now to engage with quantum experts!