Dynamiqs: Optimal control background cover

Dynamiqs: Optimal control

Quantum control of a cavity coupled to a qubit

Alice & Bob

Hosted by

Alice & Bob

Khalamendyk Ivan
Lubitelua

1

Posted

--- Revolutionary ψ-PINN for Dynamiqs-2025-2: Optimal Control ---

А що якщо ми будемо вважати, що простір = довжина хвилі, а час = частота, і тоді керування системою — це просто еволюція вузлів ψ-поля? Я спробував підхід через ψ-PINN (physics-informed neural network): – MLP апроксимує ψ(x,t,control), – loss = fidelity + physics residual + topo loss (winding), – оптимізація через Adam + BlackJAX. Це дозволяє ловити топологічні ефекти, які стандартний GRAPE губить, і потенційно підняти fidelity вище 0.99. Нижче – код, який реалізує цей підхід. Він експериментальний, але працює на тій же бібліотеці dynamiqs.

Order by:

Khalamendyk Ivan

1

Posted by Lubitelua

--- Revolutionary ψ-PINN for Dynamiqs-2025-2: Optimal Control ---
import jax, jax.numpy as jnp, optax
from jax import grad, jit, random, vmap
import numpy as np
import blackjax
import dynamiqs as dq # базова бібліотека конкурсу

# ---------- System setup ----------

# Hilbert-space dimensions: Na cavity levels, Nq qubit levels.
Na, Nq = 20, 2
a = dq.destroy(Na) # cavity annihilation operator
sigma_x = dq.pauli_x(Nq) # qubit Pauli-X (presumably drive operator — confirm against dq API)
target_state = dq.tensor(dq.fock(Na, 2), dq.fock(Nq, 0)) # target |2,g>: two photons, qubit in ground
psi0 = dq.tensor(dq.fock(Na, 0), dq.fock(Nq, 0)) # initial |0,g>: vacuum, qubit in ground
# Number of control time slices; tsave holds the evaluation times in [0, 10].
steps = 100
tsave = jnp.linspace(0, 10, steps)

# ---------- MLP for ψ-PINN ----------

def mlp(params, inputs):
    """Apply a tanh multilayer perceptron to `inputs`.

    Args:
        params: list of (w, b) tuples, one per layer; the last layer is linear.
        inputs: array of shape [..., d_in].

    Returns:
        Array of shape [..., d_out] (d_out = 1 for the ψ-PINN ansatz).
    """
    # Hidden layers use tanh; the output layer is affine (no activation).
    *hidden, (w_out, b_out) = params
    for w, b in hidden:
        inputs = jnp.tanh(inputs @ w + b)
    return inputs @ w_out + b_out  # [..., 1]
def init_mlp(key, layers=(3 + 4, 32, 16, 1)):  # input: (x, t, 4 controls)
    """Initialize MLP parameters with 1/sqrt(fan_in) scaling.

    Args:
        key: JAX PRNG key.
        layers: layer widths; default input width is 3 + 4 = 7
            (x, t, plus 4 control values). Changed from a mutable
            list default to a tuple — mutable defaults are shared
            across calls.

    Returns:
        List of (w, b) tuples, one per layer.
    """
    keys = random.split(key, len(layers) - 1)
    params = []
    for kin, kout, k in zip(layers[:-1], layers[1:], keys):
        w = random.normal(k, (kin, kout)) / jnp.sqrt(kin)
        b = jnp.zeros(kout)
        params.append((w, b))
    return params
def psi_pinn(params, x, t, ctrl):  # ctrl: [4]
    """Evaluate the PINN ansatz ψ(x, t, ctrl) as a SCALAR.

    BUG FIX: the original returned `mlp(...)[0]`, a shape-(1,) array;
    jax.grad (used on ψ in physics_residual) requires a scalar-valued
    function, so we index down to a true scalar.
    """
    inp = jnp.concatenate([jnp.array([x, t]), ctrl])
    return mlp(params, inp[None, :])[0, 0]

# ---------- Physics residual ----------

# Nonlinearity strength for the cubic ψ³ term in the heuristic residual.
λ = 3 / (2 * jnp.pi)


def physics_residual(params, x_grid, t, ctrl, u):
    """Mean-squared residual of the heuristic ψ-field equation at time t.

    Args:
        params: MLP parameters for psi_pinn.
        x_grid: 1-D spatial grid.
        t: scalar time (differentiated through, so must be a JAX scalar).
        ctrl: length-4 control vector at this time slice.
        u: pair (g2, kappa) of physical coefficients.

    Returns:
        Scalar mean of res**2 over the grid.

    NOTE(review): dpsi_dt and d2psi_dx2 are evaluated only at x_grid[0]
    and broadcast against ψ over the whole grid — presumably a
    simplification; confirm against the intended PDE.
    """
    psi_fn = lambda x: psi_pinn(params, x, t, ctrl)
    psi = vmap(psi_fn)(x_grid)  # ψ sampled over the spatial grid
    # jnp.sum(...) guarantees a scalar output for jax.grad even if
    # psi_pinn returns a length-1 array (the original code would raise
    # a TypeError from grad on a non-scalar function).
    dpsi_dt = grad(lambda tt: jnp.sum(psi_pinn(params, x_grid[0], tt, ctrl)))(t)
    d2psi_dx2 = grad(grad(lambda xx: jnp.sum(psi_pinn(params, xx, t, ctrl))))(x_grid[0])
    g2, kappa = u
    res = d2psi_dx2 + dpsi_dt + psi + λ * psi**3 + g2 * psi * ctrl[0] - kappa * psi
    return jnp.mean(res**2)


Khalamendyk Ivan

1

Posted by Lubitelua

# ---------- Fidelity & topo ----------

def fidelity(psi_final, target):
    """Squared-magnitude overlap proxy between psi_final and target.

    NOTE(review): `dq.expect(dq.dag(target) @ psi_final, psi_final)` is an
    unusual use of expect for a state overlap — confirm against the
    dynamiqs API (dq.fidelity may be the intended call).
    """
    return jnp.abs(dq.expect(dq.dag(target) @ psi_final, psi_final)) ** 2


def topo_loss(psi_t):
    """Penalty on the mean absolute phase winding of ψ(t).

    Expects winding ~2 for the Fock |2> target; deviations are
    penalized quadratically.
    """
    phi = jnp.angle(psi_t)
    winding = jnp.mean(jnp.abs(jnp.diff(phi)))
    return (winding - 2.0) ** 2

# ---------- Full loss ----------

# Spatial grid on which the PINN residual and ψ snapshots are evaluated.
x_grid = jnp.linspace(-5, 5, 16)


@jit
def loss_fn(mlp_params, controls, u):
    """Total training loss.

    Combines: infidelity to the target state, mean physics residual over
    all time slices, topological winding penalty on the final ψ, and an
    L2 regularizer on the controls.

    Args:
        mlp_params: PINN parameters.
        controls: array of shape (steps, 4), one control vector per time.
        u: pair (g2, kappa) of physical coefficients.
    """
    # Accumulate the physics residual over every saved time slice.
    phys = 0.0
    for i, t in enumerate(tsave):
        phys += physics_residual(mlp_params, x_grid, t, controls[i], u)
    phys /= steps
    # ψ over the full (time, space) grid: shape (steps, len(x_grid), ...).
    psi_final = vmap(
        lambda t, ctrl: vmap(lambda x: psi_pinn(mlp_params, x, t, ctrl))(x_grid)
    )(tsave, controls)
    # NOTE(review): `.mean()` collapses ψ over space before the fidelity —
    # presumably a crude projection; confirm this is intended.
    fid = 1 - fidelity(psi_final[-1].mean(), target_state)
    topo = topo_loss(psi_final[-1])
    reg = 1e-3 * jnp.sum(controls**2)
    # BUG FIX: the original read `0.1phys + 0.01topo` (a SyntaxError);
    # restore the intended scalar multiplications.
    return fid + 0.1 * phys + 0.01 * topo + reg

# ---------- Optimization ----------

def run_opt(mlp_params, controls0, u0, steps=200, lr=1e-3):
    """Optimize (MLP params, controls, u) with Adam, then refine the
    controls with BlackJAX HMC.

    Args:
        mlp_params: initial PINN parameters.
        controls0: initial controls, shape (n_steps, 4).
        u0: initial (g2, kappa).
        steps: Adam iterations.
        lr: Adam learning rate.

    Returns:
        The HMC chain's final controls position.
    """
    opt = optax.adam(lr)
    params = {'mlp': mlp_params, 'controls': controls0, 'u': u0}
    opt_state = opt.init(params)

    def objective(p):
        # Single-pytree objective so gradients mirror the params structure.
        return loss_fn(p['mlp'], p['controls'], p['u'])

    for _ in range(steps):
        # BUG FIX: the original differentiated with argnums=(0, 1, 2),
        # producing a grads *tuple* that does not match the params *dict*
        # optax was initialized with; differentiate the pytree directly.
        val, grads = jax.value_and_grad(objective)(params)
        updates, opt_state = opt.update(grads, opt_state)
        params = optax.apply_updates(params, updates)

    # Refine controls with BlackJAX HMC around the Adam optimum.
    # NOTE(review): recent blackjax.hmc also requires an inverse mass
    # matrix argument, and reusing PRNGKey(0) every step makes the chain
    # degenerate — confirm against the installed blackjax version.
    logdensity = lambda c: -loss_fn(params['mlp'], c, params['u'])
    kernel = blackjax.hmc(logdensity, step_size=1e-3, num_integration_steps=20)
    hmc_state = kernel.init(params['controls'])
    for _ in range(50):
        hmc_state, _ = kernel.step(random.PRNGKey(0), hmc_state)
    return hmc_state.position

# ---------- Main ----------

key = random.PRNGKey(0)
# BUG FIX: the original reused `key` for both the MLP init and the
# initial controls; JAX PRNG keys must be split so the two draws are
# statistically independent.
mlp_key, ctrl_key = random.split(key)
mlp_params = init_mlp(mlp_key)
controls0 = random.normal(ctrl_key, (steps, 4)) * 0.1  # small random initial pulses
u0 = jnp.array([0.1, 0.05])  # initial (g2, kappa)
optimal_controls = run_opt(mlp_params, controls0, u0)
print("Optimal controls shape:", optimal_controls.shape)
print("Flattened for submission:", optimal_controls.flatten())

Want to join this discussion?

Join our community today and start discussing with our members by participating in exciting events, competitions, and challenges. Sign up now to engage with quantum experts!