diff --git a/pyTurbulence/solver.py b/pyTurbulence/solver.py
index 250e4012e211b88bac6ba8c0338f4c43449d2047..42127a96884266ad2ea22b52ec7fee5fc398f4a7 100644
--- a/pyTurbulence/solver.py
+++ b/pyTurbulence/solver.py
@@ -24,7 +24,7 @@ def solve(user_params)->dict:
         - "domain_resolution" (list, tuple, np.ndarray): Resolution of the domain, should be a 3D vector (default: [64, 64, 64]).
         - "Spectrum" (str): Type of spectrum for turbulence generation. Options: "PassotPouquet", "Gaussian" (default: "PassotPouquet").
         - "gamma" (float): Heat capacity ratio, typically between 1 and 2 (default: 1.4).
-        - "case" (int): Dimensionalisation of the density and temperature fluctuations (from Ristorcelli & Blaisdell). Options: 1, 2 (default: 1).
+        - "case" (int): Type of turbulent fields to generate. Options: 0, 1 (default: 1). 0: incompressible - 1: compressible.
         - "output_dir" (str): Directory to store output results (default: "results").
 
     Returns
@@ -60,16 +60,16 @@ def solve(user_params)->dict:
 
     # Generate the solenoidal velocity field
     start = time.time()
-    u, v, w, wmax = compute_solenoidal_velocities(spectrum, nbModes, urms, 
+    us, vs, ws, wmax = compute_solenoidal_velocities(spectrum, nbModes, urms, 
                                                   k0, domain_size, 
                                                   domain_resolution)
-    TKE = np.mean(0.5 * (u.reshape(-1) ** 2 + v.reshape(-1) ** 2 + w.reshape(-1) ** 2))
+    TKE = np.mean(0.5 * (us.reshape(-1) ** 2 + vs.reshape(-1) ** 2 + ws.reshape(-1) ** 2))
     end = time.time()
     solenoidal_time = end - start
 
     # Generate the incompressible pressure fluctuations
     start = time.time()
-    incompressible_pressure_fluctuations = compute_solenoidal_pressure(u, v, w, 
+    incompressible_pressure_fluctuations = compute_solenoidal_pressure(us, vs, ws, 
                                                                        domain_resolution, 
                                                                        domain_size)
     end = time.time()
@@ -77,37 +77,44 @@ def solve(user_params)->dict:
 
     # Generate the dilatational velocity field
     start = time.time()
-    ud, vd, wd = compute_dilatational_velocities(u, v, w, 
+    ud, vd, wd = compute_dilatational_velocities(us, vs, ws, 
                                                  incompressible_pressure_fluctuations, 
                                                  gamma, domain_resolution, domain_size)
-    u += ud*gamma*Mt**2
-    v += vd*gamma*Mt**2
-    w += wd*gamma*Mt**2
     end = time.time()
     dilatational_time = end - start
 
     # Compute the thermodynamic fields
     start = time.time()
-    density, pressure, temperature = compute_thermodynamic_fields(mean_density, 
-                                                                  mean_pressure, 
-                                                                  mean_temperature, 
-                                                                  incompressible_pressure_fluctuations, 
-                                                                  gamma, Mt, params["case"])
+    density, pressure, temperature, u, v, w = compute_thermodynamic_fields(mean_density, 
+                                                                           mean_pressure, 
+                                                                           mean_temperature, 
+                                                                           us,vs,ws,
+                                                                           ud,vd,wd,
+                                                                           incompressible_pressure_fluctuations, 
+                                                                           gamma, Mt, params["case"])
     end = time.time()
     thermodynamic_time = end - start
 
-    # Plot the spectrum
+    # Compute the spectrum
     start = time.time()
-    knyquist, wave_numbers, tke_spectrum = compute_tke_spectrum(u, v, w, domain_size[0], domain_size[1], domain_size[2])
+    knyquist, wave_numbers, tke_spectrum_solenoidal = compute_tke_spectrum(us, vs, ws, 
+                                                                            domain_size[0], 
+                                                                            domain_size[1], domain_size[2])
+
+    _, _,                 tke_spectrum_dilatational = compute_tke_spectrum(ud, vd, wd, 
+                                                                            domain_size[0], 
+                                                                            domain_size[1], domain_size[2])
     real_spectrum = energy_spectrum(spectrum, wave_numbers, urms, k0)
-    spectrum_path = os.path.join(results_dir, "spectrum.pdf")
-    plot_spectrum(wmax, wave_numbers, tke_spectrum, real_spectrum=real_spectrum, save_path=spectrum_path)
+    solenoidal_spectrum_path    = os.path.join(results_dir, "spectrum_solenoidal.pdf")
+    dilatational_spectrum_path  = os.path.join(results_dir, "spectrum_dilatational.pdf")
+    plot_spectrum(wmax, wave_numbers, tke_spectrum_solenoidal,   real_spectrum=real_spectrum, save_path=solenoidal_spectrum_path)
+    plot_spectrum(wmax, wave_numbers, tke_spectrum_dilatational, real_spectrum=real_spectrum, save_path=dilatational_spectrum_path)
     end = time.time()
     spectrum_time = end - start
 
     # Save the spectrum data
     spectrum_data_path = os.path.join(results_dir, "spectrum.dat")
-    np.savetxt(spectrum_data_path, np.column_stack((wave_numbers, tke_spectrum)), header="wave_numbers tke_spectrum")
+    np.savetxt(spectrum_data_path, np.column_stack((wave_numbers, tke_spectrum_solenoidal)), header="wave_numbers tke_spectrum")
 
     # Save the data
     start = time.time()
@@ -141,7 +148,8 @@ def solve(user_params)->dict:
         "pressure": pressure,
         "temperature": temperature,
         "wave_numbers": wave_numbers,
-        "tke_spectrum": tke_spectrum,
+        "tke_spectrum_solenoidal": tke_spectrum_solenoidal,
+        "tke_spectrum_dilatational": tke_spectrum_dilatational,
         "real_spectrum": real_spectrum
     }
 
@@ -203,7 +211,7 @@ def _validate_params(params) -> dict:
     # Define allowed values for specific parameters
     allowed_values = {
         "Spectrum": {"PassotPouquet", "Gaussian"},  # Allowed spectrum types
-        "case": {1, 2},  # Allowed integer cases
+        "case": {0, 1},  # Allowed integer cases
     }
 
     # Define numeric ranges for specific parameters
diff --git a/pyTurbulence/spectrum.py b/pyTurbulence/spectrum.py
index a7ae126a0b9812191f9cabeb1cabd8e1890ed93b..8fbf2a79970c50cb9ad33d41917b92ab5c4611d6 100644
--- a/pyTurbulence/spectrum.py
+++ b/pyTurbulence/spectrum.py
@@ -26,8 +26,8 @@ def energy_spectrum(spectrum: str, k: np.ndarray, urms: float, k0: float) -> np.
     E_k = np.empty_like(k)
     if spectrum == "PassotPouquet":
         E_k = (urms**2 * 16. * np.sqrt(2. / np.pi) * k**4 / k0**5) * np.exp(-2. * (k / k0)**2)
-    if spectrum == "Constant":
-        E_k = urms*k**4*np.exp(-2.*k**2/k0**2)
+    if spectrum == "Gaussian":
+        E_k = 0.00013*k**4*np.exp(-2.*k**2/k0**2)
     return E_k
 
 def compute_tke_spectrum(u : np.ndarray,v : np.ndarray,w : np.ndarray,
@@ -108,7 +108,8 @@ def plot_spectrum(wmax, wave_numbers, tke_spectrum, real_spectrum=None, save_pat
     
     plt.xlabel('Wavenumber (k)')
     plt.ylabel('Energy Spectrum (E(k))')
-    plt.ylim(np.min(tke_spectrum), 1.5*np.max(tke_spectrum))
+    # plt.ylim(np.min(tke_spectrum), 1.5*np.max(tke_spectrum))
+    plt.ylim(1e-6, 1.5*np.max(tke_spectrum))
     plt.legend()
     plt.grid()
     if save_path:
diff --git a/pyTurbulence/syntheticTurbulence.py b/pyTurbulence/syntheticTurbulence.py
index 66a089864666c5def57f7188999aa286afca2e96..a6afbdf35392ea5e0dd1b5d0b7551b0009dfbdf6 100644
--- a/pyTurbulence/syntheticTurbulence.py
+++ b/pyTurbulence/syntheticTurbulence.py
@@ -131,7 +131,7 @@ def compute_solenoidal_pressure(u: np.ndarray,v: np.ndarray,w: np.ndarray,
                                 domain_resolution: Tuple[int,int,int],
                                 domain_size: Tuple[float,float,float]) -> np.ndarray:
     """
-    Compute the solenoidal pressure field from the velocity components.
+    Compute the solenoidal pressure field from the solenoidal velocity components.
 
     Parameters
     ----------
@@ -168,10 +168,14 @@ def compute_solenoidal_pressure(u: np.ndarray,v: np.ndarray,w: np.ndarray,
     return pressure
 
 def compute_thermodynamic_fields(mean_density : float, mean_pressure : float, mean_temperature : float,
+                                 u: np.ndarray, v: np.ndarray, w: np.ndarray, 
+                                 ud: np.ndarray, vd: np.ndarray, wd: np.ndarray,
                                  incompressible_pressure_fluctuations: np.ndarray, 
-                                 gamma: float, Mt : float, case: int)->Tuple[np.ndarray,np.ndarray,np.ndarray]:
+                                 gamma: float, Mt : float, case: int)->Tuple[np.ndarray,np.ndarray,np.ndarray,
+                                                                             np.ndarray,np.ndarray,np.ndarray]:
     """
     Compute the density, pressure and temperature fields from the incompressible pressure fluctuations.
+    Combine solenoidal and dilatational velocity fields to get the total velocity field.
     Eq (4-5) + ansatz from Ristorcelli and Blaisdell : p' = gamma * Mt^2 * p1
     
     Parameters
@@ -182,6 +186,18 @@ def compute_thermodynamic_fields(mean_density : float, mean_pressure : float, me
         Mean pressure.
     mean_temperature : float
         Mean temperature.
+    u : np.ndarray
+        Solenoidal velocity component in the x-direction.
+    v : np.ndarray
+        Solenoidal velocity component in the y-direction.
+    w : np.ndarray
+        Solenoidal velocity component in the z-direction.
+    ud : np.ndarray
+        Dilatational velocity component in the x-direction.
+    vd : np.ndarray
+        Dilatational velocity component in the y-direction.
+    wd : np.ndarray
+        Dilatational velocity component in the z-direction.
     incompressible_pressure_fluctuations : np.ndarray
         Incompressible pressure fluctuations.
     gamma : float
@@ -199,23 +215,29 @@ def compute_thermodynamic_fields(mean_density : float, mean_pressure : float, me
         Pressure field.
     temperature : np.ndarray
         Temperature field.
+    u : np.ndarray
+        Velocity component in the x-direction.
+    v : np.ndarray
+        Velocity component in the y-direction.
+    w : np.ndarray
+        Velocity component in the z-direction
     """
-    compressible_pressure_fluctuations = incompressible_pressure_fluctuations * gamma * Mt**2 
+    overGamma = 1./gamma
 
-    if case == 1:
-        compressible_density_fluctuations = compressible_pressure_fluctuations
-        compressible_temperature_fluctuations = compressible_pressure_fluctuations
+    if case == 0: Mt =0.0 # no fluctuations case
+    compressible_pressure_fluctuations = incompressible_pressure_fluctuations * gamma * Mt**2 
 
-    if case == 2:
-        compressible_density_fluctuations = compressible_pressure_fluctuations*Mt**2
-        compressible_temperature_fluctuations = compressible_pressure_fluctuations*Mt**2*(gamma-1)
+    u += ud*gamma*Mt**2
+    v += vd*gamma*Mt**2
+    w += wd*gamma*Mt**2
 
-    # Get fields of density - temperature - pressure
-    density     = mean_density*(1.+compressible_density_fluctuations)
-    temperature = mean_temperature*(1.+compressible_temperature_fluctuations)
+    # Assuming R = 1 and rho' = overGamma*p'
+    compressible_density_fluctuations = overGamma*compressible_pressure_fluctuations
+    density     = mean_density*(1.+compressible_density_fluctuations) 
     pressure    = mean_pressure*(1.+compressible_pressure_fluctuations)
+    temperature = pressure/density
 
-    return density, pressure, temperature
+    return density, pressure, temperature, u, v, w
 
 def compute_dilatational_velocities(u: np.ndarray,v: np.ndarray,w: np.ndarray,p: np.ndarray,
                                     gamma: float, domain_resolution: Tuple[int,int,int],
@@ -277,7 +299,7 @@ def compute_dilatational_velocities(u: np.ndarray,v: np.ndarray,w: np.ndarray,p:
 
     # Compute dilatation from Eq.7
     # -gamma*dilatation = ddt_p+v_k*dpdxk
-    dilatation = -overGamma*(np.sum(velocity * gradPressure, axis=0) + ddt_p)
+    dilatation = -overGamma*(np.einsum('k...,k...->...',velocity,gradPressure) + ddt_p)
 
     # Compute dilatational velocity from Eq.12
     vd_x, vd_y, vd_z = dilatational_velocities(dilatation, nx, ny, nz, hx, hy, hz)