diff --git a/pyTurbulence/solver.py b/pyTurbulence/solver.py
index 9dbb6c685af65f4b76f6f170eb065cbe7e858775..d3131e7395f15ecac937dd1e782f3109023cbdd7 100644
--- a/pyTurbulence/solver.py
+++ b/pyTurbulence/solver.py
@@ -60,20 +60,26 @@ def solve(user_params)->dict:
 
     # Generate the solenoidal velocity field
     start = time.time()
-    u, v, w, wmax = compute_solenoidal_velocities(spectrum, nbModes, urms, k0, domain_size, domain_resolution)
+    u, v, w, wmax = compute_solenoidal_velocities(spectrum, nbModes, urms, 
+                                                  k0, domain_size, 
+                                                  domain_resolution)
     TKE = np.mean(0.5 * (u.reshape(-1) ** 2 + v.reshape(-1) ** 2 + w.reshape(-1) ** 2))
     end = time.time()
     solenoidal_time = end - start
 
     # Generate the incompressible pressure fluctuations
     start = time.time()
-    incompressible_pressure_fluctuations = compute_solenoidal_pressure(u, v, w, domain_resolution, domain_size)
+    incompressible_pressure_fluctuations = compute_solenoidal_pressure(u, v, w, 
+                                                                       domain_resolution, 
+                                                                       domain_size)
     end = time.time()
     pressure_time = end - start
 
     # Generate the dilatational velocity field
     start = time.time()
-    ud, vd, wd = compute_dilatational_velocities(u, v, w, incompressible_pressure_fluctuations, gamma, domain_resolution, domain_size)
+    ud, vd, wd = compute_dilatational_velocities(u, v, w, 
+                                                 incompressible_pressure_fluctuations, 
+                                                 gamma, domain_resolution, domain_size)
     u += ud*gamma*Mt**2
     v += vd*gamma*Mt**2
     w += wd*gamma*Mt**2
@@ -82,7 +88,11 @@ def solve(user_params)->dict:
 
     # Compute the thermodynamic fields
     start = time.time()
-    density, pressure, temperature = compute_thermodynamic_fields(mean_density, mean_pressure, mean_temperature, incompressible_pressure_fluctuations, gamma, Mt, params["case"])
+    density, pressure, temperature = compute_thermodynamic_fields(mean_density, 
+                                                                  mean_pressure, 
+                                                                  mean_temperature, 
+                                                                  incompressible_pressure_fluctuations, 
+                                                                  gamma, Mt, params["case"])
     end = time.time()
     thermodynamic_time = end - start
 
diff --git a/pyTurbulence/spectrum.py b/pyTurbulence/spectrum.py
index a6d2ac36cf8cbe09a296f46f5cc8744797276209..a7ae126a0b9812191f9cabeb1cabd8e1890ed93b 100644
--- a/pyTurbulence/spectrum.py
+++ b/pyTurbulence/spectrum.py
@@ -30,7 +30,8 @@ def energy_spectrum(spectrum: str, k: np.ndarray, urms: float, k0: float) -> np.
         E_k = urms*k**4*np.exp(-2.*k**2/k0**2)
     return E_k
 
-def compute_tke_spectrum(u,v,w,lx,ly,lz,smooth=True):
+def compute_tke_spectrum(u : np.ndarray,v : np.ndarray,w : np.ndarray,
+                         lx : float,ly : float,lz : float,smooth : bool =True) -> tuple:
     """
     Compute the turbulent kinetic energy (TKE) spectrum from the velocity fields.
 
diff --git a/pyTurbulence/syntheticTurbulence.py b/pyTurbulence/syntheticTurbulence.py
index b9d1150d5c5cb98d55f5a118a71170a302b0d490..e2b7dfef178fba5467eaf58ecce11c4981be269b 100644
--- a/pyTurbulence/syntheticTurbulence.py
+++ b/pyTurbulence/syntheticTurbulence.py
@@ -5,7 +5,9 @@ from numba import jit, prange
 from typing import Tuple
 
 @jit(nopython=True, parallel=True)
-def _compute_velocity_field(u_, v_, w_, nx, ny, nz, kx, ky, kz, xc, yc, zc, um, sxm, sym, szm, dx, dy, dz, psi):
+def _compute_velocity_field(u_, v_, w_, nx, ny, nz, kx, ky, kz, 
+                            xc, yc, zc, um, sxm, sym, szm, 
+                            dx, dy, dz, psi):
     for k in prange(nz):  
         for j in prange(ny):
             for i in prange(nx):
@@ -172,7 +174,9 @@ def compute_solenoidal_pressure(u: np.ndarray,v: np.ndarray,w: np.ndarray,
     pressure = poisson_3d(rhs, nx, ny, nz, hx, hy, hz)
     return pressure
 
-def compute_thermodynamic_fields(mean_density, mean_pressure, mean_temperature, incompressible_pressure_fluctuations, gamma, Mt, case):
+def compute_thermodynamic_fields(mean_density : float, mean_pressure : float, mean_temperature : float,
+                                 incompressible_pressure_fluctuations: np.ndarray, 
+                                 gamma: float, Mt : float, case: int)->Tuple[np.ndarray,np.ndarray,np.ndarray]:
     """
     Compute the density, pressure and temperature fields from the incompressible pressure fluctuations.
     Eq (4-5) + ansatz from Ristorcelli and Blaisdell : p' = gamma * Mt^2 * p1
@@ -259,48 +263,33 @@ def compute_dilatational_velocities(u: np.ndarray,v: np.ndarray,w: np.ndarray,p:
     hy = domain_size[1]/ny
     hz = domain_size[2]/nz
 
+    overGamma = 1./gamma
+
     velocity       = np.array([u,v,w])
-    dudx,dudy,dudz = np.gradient(u,hx,hy,hz,edge_order=2)
-    dvdx,dvdy,dvdz = np.gradient(v,hx,hy,hz,edge_order=2)
-    dwdx,dwdy,dwdz = np.gradient(w,hx,hy,hz,edge_order=2)
-    gradVelocity   = np.array([[dudx,dudy,dudz],
-                               [dvdx,dvdy,dvdz],
-                               [dwdx,dwdy,dwdz]])
-    dpdx,dpdy,dpdz = np.gradient(p,hx,hy,hz,edge_order=2)
-    gradPressure   = np.array([dpdx,dpdy,dpdz])
+    gradVelocity   = np.array([np.gradient(vel, hx, hy, hz, edge_order=2) for vel in velocity])
+    gradPressure   = np.gradient(p,hx,hy,hz,edge_order=2)
 
     # Compute time derivative of pressure fluctuations (Eq.11)
-    #rhs = 2*((u_k * du_i/dx_k + dp/dx_i) * u_j)
-    rhs = np.zeros([nx,ny,nz])
-    for i in range(3):
-        for j in range(3):
-            for k in range(3):
-                rhs += 2.*(velocity[k]*gradVelocity[i][k]+gradPressure[i])*velocity[j]
+    #rhs = ((u_k * du_i/dx_k + dp/dx_i) * u_j)
+    rhs = np.einsum('k...,ik...->i...', velocity, gradVelocity) + gradPressure
+    rhs = np.einsum('i...,j...->...', rhs, velocity)
 
     # compute second derivative of rhs 
-    # (p,t),jj = (rhs),ij
-    drhsdx,drhsdy,drhsdz = np.gradient(rhs,hx,hy,hz,edge_order=2)
-    drhsdxdx, drhsdxdy, drhsdxdz = np.gradient(drhsdx,hx,hy,hz,edge_order=2)
-    drhsdydx, drhsdydy, drhsdydz = np.gradient(drhsdy,hx,hy,hz,edge_order=2)
-    drhsdzdx, drhsdzdy, drhsdzdz = np.gradient(drhsdz,hx,hy,hz,edge_order=2)
-
-    hessian_rhs = np.array([[drhsdxdx,drhsdxdy,drhsdxdz],
-                            [drhsdydx,drhsdydy,drhsdydz],
-                            [drhsdzdx,drhsdzdy,drhsdzdz]])
-    rhs_ij = np.zeros([nx,ny,nz])
-    for i in range(3):
-        for j in range(3):
-            rhs_ij += hessian_rhs[i][j]
+    # (p,t),jj = 2.*(rhs),ij
+    grad_rhs    = np.gradient(rhs, hx, hy, hz, edge_order=2)
+    hessian_rhs = np.array([np.gradient(grad, hx, hy, hz, edge_order=2) for grad in grad_rhs])
+    rhs_ij      = np.einsum('ij...->...', hessian_rhs) 
+    rhs_ij      *= 2.0
+
+    print(f"mean rhs_ij = {np.mean(rhs_ij)}")
 
     ddt_p = poisson_3d(rhs_ij, nx, ny, nz, hx, hy, hz)
 
     # Compute dilatation from Eq.7
     # -gamma*dilatation = ddt_p+v_k*dpdxk
-    dilatation = np.zeros_like(p)
-    for i in range(3):
-        dilatation += velocity[i]*gradPressure[i]
-    dilatation += ddt_p
-    dilatation /= -gamma
+    dilatation = -overGamma*(np.sum(velocity * gradPressure, axis=0) + ddt_p)
+
+    print(f"mean dilatation = {np.mean(dilatation)}")
 
     # Compute dilatational velocity from Eq.12
     vd_x, vd_y, vd_z = dilatational_velocities(dilatation, nx, ny, nz, hx, hy, hz)