Khalamendyk Ivan

Khalamendyk Ivan

Joined August 12, 2025

Karma 0

Khalamendyk Ivan

1

Posted by Lubitelua

# ---------- Fidelity & topo ----------

def fidelity(psi_final, target):
    """Squared overlap |<target|psi_final>|^2 of the evolved state with the target.

    NOTE(review): dq.expect with the operator dag(target) @ psi_final is an
    unusual way to spell this overlap -- confirm it matches |<target|psi>|^2
    for the state conventions used here.
    """
    return jnp.abs(dq.expect(dq.dag(target) @ psi_final, psi_final)) ** 2


def topo_loss(psi_t):
    """Penalty on the mean phase winding of psi_t, targeting a winding of ~2."""
    phi = jnp.angle(psi_t)
    winding = jnp.mean(jnp.abs(jnp.diff(phi)))
    return (winding - 2.0) ** 2  # expect winding ~2 for Fock |2>

# ---------- Full loss ----------

x_grid = jnp.linspace(-5, 5, 16)


@jit
def loss_fn(mlp_params, controls, u):
    """Total objective: infidelity + physics residual + topology + control L2 reg.

    Args:
        mlp_params: PINN weights (list of (w, b) tuples).
        controls: [steps, 4] control amplitudes, one row per time in `tsave`.
        u: physical parameters forwarded to `physics_residual`.
    """
    # Time-averaged physics residual along the control schedule.
    phys = 0.0
    for i, t in enumerate(tsave):
        phys += physics_residual(mlp_params, x_grid, t, controls[i], u)
    phys /= steps

    # Evaluate the PINN wavefunction on the full (t, x) grid.
    psi_final = vmap(
        lambda t, ctrl: vmap(lambda x: psi_pinn(mlp_params, x, t, ctrl))(x_grid)
    )(tsave, controls)

    # NOTE(review): psi_final[-1].mean() collapses the spatial grid to a single
    # scalar before the fidelity call -- confirm this is the intended state.
    fid = 1 - fidelity(psi_final[-1].mean(), target_state)
    topo = topo_loss(psi_final[-1])
    reg = 1e-3 * jnp.sum(controls ** 2)
    # Restored the '*' operators stripped by markdown: 0.1*phys, 0.01*topo.
    return fid + 0.1 * phys + 0.01 * topo + reg

# ---------- Optimization ----------

def run_opt(mlp_params, controls0, u0, steps=200, lr=1e-3):
    """Adam pre-optimization of (mlp, controls, u), then HMC refinement of controls.

    Returns the final HMC position (the refined control array).
    """
    opt = optax.adam(lr)
    params = {'mlp': mlp_params, 'controls': controls0, 'u': u0}
    state = opt.init(params)

    # Differentiate w.r.t. the whole dict so the gradient pytree matches the
    # dict-shaped optimizer state.  The original took value_and_grad with
    # argnums=(0,1,2) (a tuple of grads) and fed it to a dict-initialized
    # optimizer, which optax rejects as a tree-structure mismatch.
    def dict_loss(p):
        return loss_fn(p['mlp'], p['controls'], p['u'])

    for i in range(steps):
        val, grads = jax.value_and_grad(dict_loss)(params)
        updates, state = opt.update(grads, state)
        params = optax.apply_updates(params, updates)

    # Refine controls with BlackJAX HMC around the Adam solution.
    kernel = blackjax.hmc(
        lambda c: -loss_fn(params['mlp'], c, params['u']),
        step_size=1e-3,
        num_integration_steps=20,
    )
    hmc_state = kernel.init(params['controls'])
    key = random.PRNGKey(0)
    for _ in range(50):
        # Split a fresh key each step: reusing PRNGKey(0) (as the original did)
        # makes every HMC proposal identical.
        key, subkey = random.split(key)
        hmc_state, _ = kernel.step(subkey, hmc_state)
    return hmc_state.position

# ---------- Main ----------

# Entry point: initialize the PINN, run the optimization, print the result.
key = random.PRNGKey(0)
mlp_params = init_mlp(key)
# Small random initial controls: [steps, 4] drive amplitudes.
controls0 = random.normal(key, (steps, 4)) * 0.1
# Initial physical parameters (presumably (g2, kappa) -- confirm with physics_residual).
u0 = jnp.array([0.1, 0.05])
optimal_controls = run_opt(mlp_params, controls0, u0)
print("Optimal controls shape:", optimal_controls.shape)
print("Flattened for submission:", optimal_controls.flatten())
Khalamendyk Ivan

1

Posted by Lubitelua

# --- Revolutionary ψ-PINN for Dynamiqs-2025-2: Optimal Control ---
import jax, jax.numpy as jnp, optax
from jax import grad, jit, random, vmap
import numpy as np
import blackjax
import dynamiqs as dq # базова бібліотека конкурсу

# ---------- System setup ----------

# Cavity with Na levels coupled to an Nq-level qubit; the goal state is the
# cavity Fock state |2> with the qubit in its ground state.
Na, Nq = 20, 2
a = dq.destroy(Na) # cavity
sigma_x = dq.pauli_x(Nq) # qubit
target_state = dq.tensor(dq.fock(Na, 2), dq.fock(Nq, 0)) # |2,g>
psi0 = dq.tensor(dq.fock(Na, 0), dq.fock(Nq, 0)) # |0,g>
# 100 control time-slices over t in [0, 10].
steps = 100
tsave = jnp.linspace(0, 10, steps)

# ---------- MLP for ψ-PINN ----------

def mlp(params, inputs):
    """Forward pass of a small MLP: tanh hidden layers, linear output layer.

    `params` is a list of (w, b) tuples; returns an array of shape [..., 1].
    (Restored the body indentation that the paste flattened.)
    """
    for w, b in params[:-1]:
        inputs = jnp.tanh(inputs @ w + b)
    w, b = params[-1]
    return inputs @ w + b  # [...,1]
def init_mlp(key, layers=[3+4, 32, 16, 1]): # input: (x,t,4 controls)
    """Initialize (w, b) pairs: w ~ N(0, 1/kin), b = 0, one per layer transition.

    NOTE(review): the default first layer is 3+4 = 7 wide, but psi_pinn feeds a
    6-vector (x, t, 4 controls) -- confirm the intended input dimension.
    (The default list is never mutated, so the mutable default is harmless.)
    """
    keys = random.split(key, len(layers) - 1)
    params = []
    for kin, kout, k in zip(layers[:-1], layers[1:], keys):
        w = random.normal(k, (kin, kout)) / jnp.sqrt(kin)
        b = jnp.zeros(kout)
        params.append((w, b))
    return params
def psi_pinn(params, x, t, ctrl): # ctrl [4]
    """Evaluate the PINN at a single (x, t) point under control vector ctrl.

    Builds a 6-vector (x, t, ctrl[0..3]) and feeds it through `mlp` with a
    singleton batch axis, returning the single output row.
    """
    inp = jnp.concatenate([jnp.array([x, t]), ctrl])
    return mlp(params, inp[None, :])[0]

# ---------- Physics residual ----------

# Cubic nonlinearity strength.
λ = 3 / (2*jnp.pi)


def physics_residual(params, x_grid, t, ctrl, u):
    """Mean-squared PDE residual of the PINN at time t over x_grid.

    u unpacks as (g2, kappa): drive strength and loss rate in the residual.
    (Restored the body indentation that the paste flattened.)
    """
    psi_fn = lambda x: psi_pinn(params, x, t, ctrl)
    psi = vmap(psi_fn)(x_grid)
    # NOTE(review): the t- and x-derivatives are evaluated only at x_grid[0]
    # while psi covers the whole grid, and grad requires a scalar output from
    # psi_pinn (it returns a length-1 array) -- confirm both against the
    # intended discretization.
    dpsi_dt = grad(lambda tt: psi_pinn(params, x_grid[0], tt, ctrl))(t)
    d2psi_dx2 = grad(grad(lambda xx: psi_pinn(params, xx, t, ctrl)))(x_grid[0])
    g2, kappa = u
    res = d2psi_dx2 + dpsi_dt + psi + λ*psi**3 + g2*psi*ctrl[0] - kappa*psi
    return jnp.mean(res**2)

Khalamendyk Ivan

1

Posted by Lubitelua

# ---------- Loss function ----------

# Exponentially decaying time weights (tau = 0.5), normalized to sum to 1,
# so early-time data dominates the fit.
w_t = jnp.exp(-tsave/0.5); w_t /= w_t.sum()
# Spatial grid for evaluating the PINN wavefunction.
x_grid = jnp.linspace(-5, 5, 32)
def loss_fn(params, u):
    """Data + physics + topology loss summed over all B experimental batches.

    Args:
        params: MLP weights.
        u: physical parameters forwarded to `physics_residual`.
    (Restored the body indentation that the paste flattened; renamed the local
    accumulator from `topo_loss` so it no longer shadows the module-level
    function of the same name.)
    """
    data_loss, phys_loss, topo_pen = 0.0, 0.0, 0.0
    for b in range(B):
        alpha = alphas[b]
        # Simulated mean |psi|^2 trace over tsave for this batch.
        sim = []
        for t in tsave:
            psi_vals = vmap(lambda x: psi_pinn(params, x, t, alpha))(x_grid)
            sim.append(jnp.mean(jnp.abs(psi_vals)**2))
        sim = jnp.array(sim)
        err = sim - data[:, b]
        data_loss += jnp.mean(w_t * err**2)
        # Physics penalty averaged over the full (x, t) grid.  Wrap the Python
        # list in jnp.asarray so jnp.mean gets a proper array.
        phys_loss += jnp.mean(jnp.asarray(
            [physics_residual(params, x, t, alpha, u) for x in x_grid for t in tsave]
        ))
        # Topology penalty: match the slope of the trace to the data's slope.
        # NOTE(review): the mean of a *signed* gradient difference can cancel
        # to zero -- confirm an absolute/squared difference wasn't intended.
        topo_pen += jnp.mean(jnp.gradient(sim) - jnp.gradient(data[:, b]))
    return data_loss + 0.1*phys_loss + 0.01*topo_pen

# ---------- Training ----------

def run_opt(params, u, steps=200, lr=1e-3):
    """Jointly Adam-optimize (params, u) for `steps` iterations.

    Returns the (params, u) pair with the lowest loss observed during training.
    (Restored the body indentation that the paste flattened.)
    """
    opt = optax.adam(lr)
    state = opt.init((params, u))
    best_val, best = 1e9, (params, u)
    for i in range(steps):
        # argnums=(0, 1) yields a (d_params, d_u) tuple, matching the
        # (params, u) pytree the optimizer state was initialized with.
        val, grads = jax.value_and_grad(loss_fn, argnums=(0, 1))(params, u)
        updates, state = opt.update(grads, state)
        params, u = optax.apply_updates((params, u), updates)
        if val < best_val:
            best_val, best = val, (params, u)
    return best

# ---------- Example run ----------

key = random.PRNGKey(0)
params = init_mlp(key)
# Initial guess for the 5 physical parameters -- presumably
# (g2, kappa_a, kappa_b, nth_a, nth_b) as unpacked by physics_residual; confirm.
u0 = jnp.array([0.1, 0.01, 0.05, 1e-3, 1e-3])
best_params, best_u = run_opt(params, u0)
print("Ψ-PINN best params:", best_u)
Khalamendyk Ivan

1

Posted by Lubitelua

# --- Ψ-PINN: physics + data + topology ---

# Three import statements were fused onto one line by the paste; restored.
import jax, jax.numpy as jnp, optax
from jax import grad, random, vmap
import numpy as np

# ---------- Load experiment data ----------

def _load_any(name): import pathlib for p in [f".aqora/data/data/{name}", f".aqora/data/{name}", f"data/{name}", name]: if pathlib.Path(p).exists(): return np.load(p, allow_pickle=True) raise FileNotFoundError(name)
# The paste collapsed three loads onto one line, turning the tsave and alphas
# assignments into a trailing comment; restored as separate statements.
# NOTE(review): dtype=jnp.float64 only takes effect if jax x64 mode is enabled
# (jax.config.update("jax_enable_x64", True)) -- confirm it is set elsewhere.
data = jnp.array(_load_any("short.npy"), dtype=jnp.float64)        # [T,B]
tsave = jnp.array(_load_any("time_short.npy"), dtype=jnp.float64)  # [T]
alphas = jnp.sqrt(data[0])  # coherent amplitudes inferred from t=0 populations
T, B = data.shape[0], data.shape[1]

# ---------- Simple MLP (PINN) ----------

def mlp(params, x):
    """MLP forward pass with tanh applied to every layer, including the last.

    NOTE(review): unlike the other mlp variant in this thread, the output layer
    is also tanh-squashed -- confirm this is intentional.
    """
    for w, b in params:
        x = jnp.tanh(jnp.dot(x, w) + b)
    return x
def init_mlp(key, layers=[2, 32, 32, 1]):
    """Initialize (w, b) pairs: w ~ N(0, 1/kin), b = 0, for the given layer sizes.

    Input is the 2-vector (x, t); the default list is never mutated, so the
    mutable default is harmless.
    """
    keys = random.split(key, len(layers) - 1)
    params = []
    for kin, kout, k in zip(layers[:-1], layers[1:], keys):
        w = random.normal(k, (kin, kout)) / jnp.sqrt(kin)
        b = jnp.zeros(kout)
        params.append((w, b))
    return params

# ---------- ψ-PINN forward ----------

def psi_pinn(params, x, t, alpha):
    """PINN ansatz: coherent amplitude `alpha` scaling the MLP output at (x, t)."""
    inp = jnp.array([x, t])
    return alpha * mlp(params, inp)

# ---------- Physics residual ----------

# Cubic nonlinearity strength.
lambda_ = 3 / (2 * jnp.pi)


def physics_residual(params, x, t, alpha, u):
    """Squared real part of the PDE residual of the PINN at a single (x, t).

    u unpacks as (g2, kappa_a, kappa_b, nth_a, nth_b): two-photon drive,
    cavity/qubit loss rates, and thermal occupations.

    Reconstruction notes: markdown stripped the '*'/'**' operators
    ('psi3' -> psi**3, 'kappa_a psi' -> kappa_a*psi,
    'nth_bjnp.mean' -> nth_b*jnp.mean, '1jg2*psi2' -> 1j*g2*psi**2) -- confirm
    against the original post.  The original also called
    grad(psi_fn, argnums=1) on a lambda taking a single tuple argument, which
    is invalid; the derivatives now target scalar x and t arguments directly.
    """
    psi_fn = lambda xx, tt: psi_pinn(params, xx, tt, alpha)
    psi = psi_fn(x, t)
    # NOTE(review): grad requires a scalar-valued function; psi_pinn returns a
    # length-1 array here -- confirm the output is reduced to a scalar upstream.
    dpsi_dt = grad(psi_fn, argnums=1)(x, t)
    d2psi_dx2 = grad(grad(psi_fn, argnums=0), argnums=0)(x, t)
    g2, kappa_a, kappa_b, nth_a, nth_b = u
    res = d2psi_dx2 + dpsi_dt + psi + lambda_ * psi**3
    # NOTE(review): reconstructed from garbled text -- confirm nth_b (not nth_a)
    # multiplies the mean-field term.
    res += -kappa_a * psi - kappa_b * (psi - nth_b * jnp.mean(psi)) - 1j * g2 * psi**2
    return jnp.real(res)**2