diff --git a/example.py b/example.py
index ae550090a90d050d8ba879e0228ee6ebf5edbdf3..1b72f8b78ad6bcd38a1fde0212ef9099cbcc4d1f 100644
--- a/example.py
+++ b/example.py
@@ -1,71 +1,33 @@
 import numpy as np
 from pyTurbulence.solver import solve
-from datetime import datetime
+from pyTurbulence.plot3D import plot3D
 import os
 
 def main():
-    # Metadata
-    code_name = "~~~ pyTurbulence: Synthetic Turbulence Generator ~~~"
-    author = "Amaury Bilocq"
-    date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+
 
     # Parameters
     params = {
-        "nbModes": 512,
+        "nbModes": 2048,
         "Mt": 0.23,
         "Pressure": 1.0,
         "Temperature": 1.0,
         "k0": 12.0,
         "domain_size": (2. * np.pi, 2. * np.pi, 2. * np.pi),
-        "domain_resolution": (64, 64, 64),
+        "domain_resolution": (64,64,64),
         "seed": 42,
         "gamma": 1.4,
         "case": 1,
-        "Spectrum": "PassotPouquet"
+        "Spectrum": "PassotPouquet",
+        "output_dir": "results"
     }
 
-    # Create results directory
-    results_dir = "results"
-    if not os.path.exists(results_dir):
-        os.makedirs(results_dir)
-
     # Run the solver
-    results = solve(params, results_dir)
-
-    # Log file
-    log_file = os.path.join(results_dir, "log.txt")
-    with open(log_file, "w") as f:
-        # Display metadata
-        f.write(f"{'='*70}\n")
-        f.write(f"{code_name:^70}\n")
-        f.write(f"{'='*70}\n")
-        f.write(f"{'Author:':<20} {author}\n")
-        f.write(f"{'Date:':<20} {date}\n")
-        f.write(f"{'='*70}\n")
-
-        # Display parameters
-        f.write(f"{'Parameters':<20}\n")
-        f.write(f"{'-'*70}\n")
-        for key, value in params.items():
-            f.write(f"{key:<20} : {value}\n")
-        f.write(f"{'='*70}\n")
-
-        # Display results
-        f.write(f"{'Task':<50} {'Time (seconds)':>20}\n")
-        f.write(f"{'-'*70}\n")
-        f.write(f"{'Computing the solenoidal velocity field':<50} {results['solenoidal_time']:>20.4f}\n")
-        f.write(f"{'Computing the incompressible pressure fluctuations':<50} {results['pressure_time']:>20.4f}\n")
-        f.write(f"{'Computing the thermodynamic fields':<50} {results['thermodynamic_time']:>20.4f}\n")
-        f.write(f"{'Plotting the spectrum':<50} {results['spectrum_time']:>20.4f}\n")
-        f.write(f"{'Saving the data':<50} {results['save_time']:>20.4f}\n")
-        f.write(f"{'-'*70}\n")
-        f.write(f"{'Total TKE':<50} {results['TKE']:>20.4f}\n")
-        f.write(f"{'Expected TKE':<50} {results['Ek']:>20.4f}\n")
-        f.write(f"{'='*70}\n")
+    data = solve(params)
 
-    # Print log to console
-    with open(log_file, "r") as f:
-        print(f.read())
+    # Plot the 3D fields
+    plot3D(data["u"], params["domain_size"], 
+           params["domain_resolution"], name="u")
 
 if __name__ == "__main__":
     main()
\ No newline at end of file
diff --git a/pyTurbulence/plot3D.py b/pyTurbulence/plot3D.py
index 8d6d5a42a6440f17c674316f524a262c8c8c6f58..c0045aa14f19c2673ca74c73e688f4001fda4d8e 100644
--- a/pyTurbulence/plot3D.py
+++ b/pyTurbulence/plot3D.py
@@ -1,7 +1,6 @@
 
 import numpy as np
 import matplotlib.pyplot as plt
-from mpl_toolkits.mplot3d import Axes3D
 
 def plot3D(data, domain_size, domain_resolution,name):
     """
diff --git a/pyTurbulence/poissonSolver.py b/pyTurbulence/poissonSolver.py
index 5ee47ab50baa06827d3e2d46bd691a64e8ca8685..e678038ff39a74330d3c9f1f04dacf1438885dd8 100644
--- a/pyTurbulence/poissonSolver.py
+++ b/pyTurbulence/poissonSolver.py
@@ -1,6 +1,6 @@
 import numpy as np
 
-def poisson_2d(f, nx, ny, hx, hy):
+def poisson_2d(f, nx, ny, hx, hy)->np.ndarray:
     """
     Solve a 2D Poisson equation with periodic boundary conditions.
     
@@ -48,7 +48,7 @@ def poisson_2d(f, nx, ny, hx, hy):
     res[:, ny-1] = res[:, 0]
     return res
 
-def poisson_3d(f, nx, ny, nz, hx, hy, hz):
+def poisson_3d(f, nx, ny, nz, hx, hy, hz)->np.ndarray:
     """
     Solve a 3D Poisson equation with periodic boundary conditions.
     
diff --git a/pyTurbulence/solver.py b/pyTurbulence/solver.py
index a5f3201ec3ec6e8c38d1d4c3dca1d09ca8b1db9c..4228e3f02b98c441fcba1c4454bfefecc7c28448 100644
--- a/pyTurbulence/solver.py
+++ b/pyTurbulence/solver.py
@@ -2,25 +2,49 @@ import numpy as np
 from pyTurbulence.syntheticTurbulence import compute_solenoidal_velocities, compute_solenoidal_pressure, compute_thermodynamic_fields
 from pyTurbulence.spectrum import compute_tke_spectrum, plot_spectrum, energy_spectrum
 from pyTurbulence.plot3D import plot3D
+from datetime import datetime
 import time
 import os
 
-def solve(params, results_dir):
+def solve(user_params)->dict:
+    """
+    Solve the synthetic turbulence problem.
+
+    Parameters
+    ----------
+    user_params : dict
+        Dictionary containing the simulation parameters.
+
+    Returns
+    -------
+    results : dict
+        Dictionary containing the results of the simulation.
+    """    
+
+    # Validate the input parameters
+    params = validate_params(user_params)
+
     # Read the params dictionary
     nbModes = params["nbModes"]
     Mt = params["Mt"]
     mean_pressure = params["Pressure"]
     mean_temperature = params["Temperature"]
-    mean_density = mean_pressure / mean_temperature
-    gamma = params["gamma"]
-    celerity = np.sqrt(gamma * mean_pressure / mean_density)
-    Ek = 0.5 * (Mt * celerity) ** 2
-    urms = np.sqrt(2. * Ek / 3.)
     k0 = params["k0"]
     domain_size = params["domain_size"]
     domain_resolution = params["domain_resolution"]
-    xi = gamma * Mt ** 2
     spectrum = params["Spectrum"]
+    gamma = params["gamma"]
+    results_dir = params["output_dir"]
+
+    # Create results directory
+    if not os.path.exists(results_dir):
+        os.makedirs(results_dir)
+
+    # Compute additional parameters
+    mean_density = mean_pressure / mean_temperature
+    celerity = np.sqrt(gamma * mean_pressure / mean_density)
+    Ek = 0.5 * (Mt * celerity) ** 2
+    urms = np.sqrt(2. * Ek / 3.)
 
     # Generate the solenoidal velocity field
     start = time.time()
@@ -50,9 +74,9 @@ def solve(params, results_dir):
     end = time.time()
     spectrum_time = end - start
 
-    # Plot the 3D fields
-    # plot3D(u, domain_size, domain_resolution, name="u")
-
+    # Save the spectrum data
+    spectrum_data_path = os.path.join(results_dir, "spectrum.dat")
+    np.savetxt(spectrum_data_path, np.column_stack((wave_numbers, tke_spectrum)), header="wave_numbers tke_spectrum")
 
     # Save the data
     start = time.time()
@@ -64,7 +88,7 @@ def solve(params, results_dir):
     end = time.time()
     save_time = end - start
 
-    return {
+    results = {
         "TKE": TKE,
         "Ek": Ek,
         "solenoidal_time": solenoidal_time,
@@ -72,4 +96,192 @@ def solve(params, results_dir):
         "thermodynamic_time": thermodynamic_time,
         "spectrum_time": spectrum_time,
         "save_time": save_time
-    }
\ No newline at end of file
+    }
+
+    # Write the log file
+    write_log_file(params, results)
+
+    # Dictionary containing the results of the simulation
+    data = {
+        "u": u,
+        "v": v,
+        "w": w,
+        "pressure": pressure,
+        "temperature": temperature,
+        "wave_numbers": wave_numbers,
+        "tke_spectrum": tke_spectrum,
+        "real_spectrum": real_spectrum
+    }
+
+
+    return data
+
+def validate_params(params) -> dict:
+    """
+    Validate the input parameters dictionary, set default values where needed, enforce allowed values, and check numeric ranges.
+
+    Parameters
+    ----------
+    params : dict
+        Dictionary containing the simulation parameters.
+
+    Returns
+    -------
+    dict
+        Validated parameters with defaults set where necessary.
+
+    Raises
+    ------
+    TypeError
+        If a parameter has an incorrect type.
+    ValueError
+        If a list parameter does not have the correct length or an invalid value.
+    """
+
+    # Define required parameters with default values
+    default_params = {
+        "nbModes": 128,  
+        "Mt": 0.2,  
+        "Pressure": 1.0,  
+        "Temperature": 1.0,  
+        "k0": 4,  
+        "domain_size": [2.0 * np.pi, 2.0 * np.pi, 2.0 * np.pi],  
+        "domain_resolution": [64, 64, 64],  
+        "Spectrum": "PassotPouquet",  
+        "gamma": 1.4,  
+        "seed": 42,  
+        "case": 1,
+        "output_dir": "results"
+    }
+
+    # Define expected types for validation
+    expected_types = {
+        "nbModes": int,
+        "Mt": (int, float),
+        "Pressure": (int, float),
+        "Temperature": (int, float),
+        "k0": (int, float),
+        "domain_size": (list, tuple, np.ndarray),
+        "domain_resolution": (list, tuple, np.ndarray),
+        "Spectrum": str,
+        "gamma": (int, float),
+        "seed": int,
+        "case": int,
+        "output_dir": str
+    }
+
+    # Define allowed values for specific parameters
+    allowed_values = {
+        "Spectrum": {"PassotPouquet", "Gaussian", "Exponential"},  # Allowed spectrum types
+        "case": {0, 1, 2},  # Allowed integer cases
+    }
+
+    # Define numeric ranges for specific parameters
+    numeric_ranges = {
+        "nbModes": (1, None),  # Must be >= 1
+        "Mt": (0, None),  # Typical turbulent Mach number is between 0 and 1
+        "Pressure": (0, None),  # Pressure must be positive
+        "Temperature": (0, None),  # Temperature must be positive
+        "k0": (1, None),  # Characteristic wavenumber must be >= 1
+        "gamma": (0, None),  # Heat capacity ratio is typically between 1 and 2
+    }
+
+    validated_params = default_params.copy()
+
+    for param, expected_type in expected_types.items():
+        if param not in params:
+            print(f"⚠️ Warning: '{param:<20}' is missing. Using default value: {default_params[param]}")
+        else:
+            validated_params[param] = params[param]  # Overwrite with user-provided value
+
+            # Type validation
+            if not isinstance(validated_params[param], expected_type):
+                raise TypeError(
+                    f"Error: Incorrect type for '{param}'. "
+                    f"Expected {expected_type}, got {type(validated_params[param])}."
+                )
+
+            # Allowed values validation
+            if param in allowed_values and validated_params[param] not in allowed_values[param]:
+                raise ValueError(
+                    f"Error: Invalid value for '{param}'. Allowed values: {allowed_values[param]}. "
+                    f"Got '{validated_params[param]}'."
+                )
+
+            # Numeric range validation
+            if param in numeric_ranges:
+                min_val, max_val = numeric_ranges[param]
+                if min_val is not None and validated_params[param] < min_val:
+                    raise ValueError(
+                        f"Error: '{param}' must be >= {min_val}. Got {validated_params[param]}."
+                    )
+                if max_val is not None and validated_params[param] > max_val:
+                    raise ValueError(
+                        f"Error: '{param}' must be <= {max_val}. Got {validated_params[param]}."
+                    )
+
+    # Additional checks for array-like parameters
+    if len(validated_params["domain_size"]) != 3:
+        raise ValueError("Error: 'domain_size' must be a list or tuple of length 3.")
+
+    if len(validated_params["domain_resolution"]) != 3:
+        raise ValueError("Error: 'domain_resolution' must be a list or tuple of length 3.")
+
+    return validated_params
+
+def get_metadata()->dict:
+    metadata = {
+        "code_name": "~~~ pyTurbulence: Synthetic Turbulence Generator ~~~",
+        "author": "Amaury Bilocq",
+        "date": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+    }
+    return metadata
+
+def write_log_file(params, results):
+    """
+    Write the log file with metadata, parameters, and results.
+
+    Parameters
+    ----------
+    log_file : str
+        Path to the log file.
+    params : dict
+        Dictionary containing the simulation parameters.
+    results : dict
+        Dictionary containing the results of the simulation.
+    """
+    log_file = os.path.join(params["output_dir"], "log.txt")
+    metadata = get_metadata()
+
+    with open(log_file, "w") as f:
+        # Display metadata
+        f.write(f"{'='*70}\n")
+        f.write(f"{metadata['code_name']:^70}\n")
+        f.write(f"{'='*70}\n")
+        f.write(f"{'Author:':<20} {metadata['author']}\n")
+        f.write(f"{'Date:':<20} {metadata['date']}\n")
+        f.write(f"{'='*70}\n")
+
+        # Display parameters
+        f.write(f"{'Parameters':<20}\n")
+        f.write(f"{'-'*70}\n")
+        for key, value in params.items():
+            f.write(f"{key:<20} : {value}\n")
+        f.write(f"{'='*70}\n")
+
+        # Display results
+        f.write(f"{'Task':<50} {'Time (seconds)':>20}\n")
+        f.write(f"{'-'*70}\n")
+        f.write(f"{'Computing the solenoidal velocity field':<50} {results['solenoidal_time']:>20.4f}\n")
+        f.write(f"{'Computing the incompressible pressure fluctuations':<50} {results['pressure_time']:>20.4f}\n")
+        f.write(f"{'Computing the thermodynamic fields':<50} {results['thermodynamic_time']:>20.4f}\n")
+        f.write(f"{'Plotting the spectrum':<50} {results['spectrum_time']:>20.4f}\n")
+        f.write(f"{'Saving the data':<50} {results['save_time']:>20.4f}\n")
+        f.write(f"{'-'*70}\n")
+        f.write(f"{'Total TKE':<50} {results['TKE']:>20.4f}\n")
+        f.write(f"{'Expected TKE':<50} {results['Ek']:>20.4f}\n")
+        f.write(f"{'='*70}\n")
+
+    # Print log to console
+    with open(log_file, "r") as f:
+        print(f.read())
diff --git a/pyTurbulence/spectrum.py b/pyTurbulence/spectrum.py
index a9af613be1f0b24501d9758fd38921a4c15246cc..a4ea495c49d48cc80ee3de468ab86a089f365822 100644
--- a/pyTurbulence/spectrum.py
+++ b/pyTurbulence/spectrum.py
@@ -1,6 +1,6 @@
 import numpy as np
 import matplotlib.pyplot as plt
-from numba import njit, typeof
+from numba import njit
 from scipy import stats
 
 @njit
@@ -102,10 +102,11 @@ def plot_spectrum(wmax, wave_numbers, tke_spectrum, real_spectrum=None, save_pat
     
     plt.xlabel('Wavenumber (k)')
     plt.ylabel('Energy Spectrum (E(k))')
+    plt.ylim(np.min(tke_spectrum), 1.5*np.max(tke_spectrum))
     plt.legend()
     plt.grid()
-    
     if save_path:
         plt.savefig(save_path)
     else:
         plt.show()
+    plt.close()