import React, { useEffect, useState, useRef } from 'react';
import { Container, Grid, Paper, Button, Typography, IconButton, CircularProgress, AppBar, Toolbar } from '@mui/material';

import UploadIcon from '@mui/icons-material/Upload';

import ArrowBackIcon from '@mui/icons-material/ArrowBack';
import * as tf from '@tensorflow/tfjs';
import * as knnClassifier from '@tensorflow-models/knn-classifier'; // Updated import
import { useNavigate } from 'react-router-dom';

// Number of audio samples in every example given to the KNN classifier; decoded training
// clips are truncated or zero-padded to exactly this many samples.
const TARGET_LENGTH = 16000;
// Number of live microphone samples classified per prediction. It must equal TARGET_LENGTH:
// live windows are compared against the stored training examples, so their lengths must match.
const REQUIRED_SIZE = TARGET_LENGTH;

const AudioClassifier = ({ name, description, type }) => {
  const [classes, setClasses] = useState([
    { id: 1, name: 'Class 1', color: '#e57373', audio: [] },
    { id: 2, name: 'Class 2', color: '#81c784', audio: [] }
  ]);
  const [isPredicting, setIsPredicting] = useState(false);
  const [isTraining, setIsTraining] = useState(false);
  const [isModelTrained, setIsModelTrained] = useState(false);
  const [classifier, setClassifier] = useState(knnClassifier.create());
  const [predictionResult, setPredictionResult] = useState(null);
  const audioRef = useRef(null);
  const audioContextRef = useRef(null);
  const microphoneRef = useRef(null);
  const audioDataBufferRef = useRef([]); // Store audio data for prediction
  const navigate = useNavigate();

  // Initialize audio context and microphone
  useEffect(() => {
    const initAudio = async () => {
      try {
        audioContextRef.current = new (window.AudioContext || window.webkitAudioContext)();
        microphoneRef.current = await navigator.mediaDevices.getUserMedia({ audio: true });
        audioRef.current = audioContextRef.current.createMediaStreamSource(microphoneRef.current);
        audioContextRef.current.resume(); // Resume the audio context after user gesture
      } catch (error) {
        console.error("Audio initialization error:", error);
        alert("Microphone access is required to use this feature. Please check your settings.");
      }
    };
  
    initAudio();
  }, []);  

  const handleBackClick = () => {
    navigate('../'); // Navigate back to the previous page
  };

  // Preprocess audio buffer
  const preprocessAudioBuffer = (audioBuffer) => {
    const channelData = audioBuffer.getChannelData(0);
    if (channelData.length > TARGET_LENGTH) {
      return channelData.slice(0, TARGET_LENGTH);
    } else {
      const paddedAudio = new Float32Array(TARGET_LENGTH);
      paddedAudio.set(channelData);
      return paddedAudio;
    }
  };

  // Handle audio upload
  const handleAudioUpload = (e, classId) => {
    const file = e.target.files[0];
    if (file) {
      const audioBlob = new Blob([file], { type: 'audio/wav' });
      setClasses(prevClasses =>
        prevClasses.map(cls =>
          cls.id === classId
            ? { ...cls, audio: [...cls.audio, audioBlob] }
            : cls
        )
      );
    }
  };

  // Add audio to classifier
  const addAudioToClassifier = (audioBlob, classId) => {
    const reader = new FileReader();
    reader.onloadend = async () => {
      const audioBuffer = await audioContextRef.current.decodeAudioData(reader.result);
      const processedAudio = preprocessAudioBuffer(audioBuffer);
      const audioTensor = tf.tensor(processedAudio).expandDims(0); // Shape: [1, TARGET_LENGTH]
      classifier.addExample(audioTensor, classId - 1);  // KNN class starts from 0
      setClasses(prevClasses =>
        prevClasses.map(cls =>
          cls.id === classId ? { ...cls, audio: [...cls.audio, audioBlob] } : cls
        )
      );
    };
    reader.readAsArrayBuffer(audioBlob);
  };

  // Record audio for training
  const recordAudio = async (classId) => {
    const mediaRecorder = new MediaRecorder(microphoneRef.current);
    const audioChunks = [];

    mediaRecorder.ondataavailable = event => {
      audioChunks.push(event.data);
    };

    mediaRecorder.onstop = () => {
      const audioBlob = new Blob(audioChunks, { type: 'audio/wav' });
      addAudioToClassifier(audioBlob, classId);
      alert(`Audio recorded for ${classes.find(cls => cls.id === classId).name}`);
    };

    mediaRecorder.start();
    setTimeout(() => {
      mediaRecorder.stop();
    }, 3000); // Adjust recording duration as needed
  };

  // Train the model
  const handleTraining = async () => {
    setIsTraining(true);
    await new Promise(resolve => setTimeout(resolve, 3000));
    setIsTraining(false);
    setIsModelTrained(true);
    alert('Training Complete!');
  };

  // Predict live audio
  const predictLiveAudio = async () => {
    const audioDataBuffer = audioDataBufferRef.current;

    if (classifier && audioDataBuffer.length >= REQUIRED_SIZE) {
      const trimmedData = audioDataBuffer.slice(-REQUIRED_SIZE); // Use the last REQUIRED_SIZE elements
      const audioTensor = tf.tensor(trimmedData).reshape([1, REQUIRED_SIZE]); // Reshape for model input

      try {
        const prediction = await classifier.predictClass(audioTensor);

        // Dynamically handle number of classes by using the size of the confidences array
        const confidencesArray = Object.values(prediction.confidences);

        // Set a confidence threshold (if needed)
        const confidenceThreshold = 0.5;

        // Filter and find the class with the highest confidence
        const validConfidences = confidencesArray.map((conf, index) => {
          return conf >= confidenceThreshold ? { index, conf } : { index, conf: 0 };
        });

        // Find the class index with the highest confidence
        const predictedClass = validConfidences.reduce((prev, curr) => {
          return (curr.conf > prev.conf) ? curr : prev;
        });

        if (predictedClass && predictedClass.conf > 0) {
          const predictedClassIndex = predictedClass.index;

          // Dynamically set the class label based on classIndex + 1 (or however your labels are structured)
          setPredictionResult({
            classIndex: predictedClassIndex + 1,  // Adjust the class index (if labels start from 1)
            label: `Class ${predictedClassIndex + 1}`,  // Adjust the label
            confidences: prediction.confidences,  // Store the confidence values
          });
        } else {
          setPredictionResult({
            classIndex: null,
            label: 'Prediction uncertain',
            confidences: prediction.confidences,
          });
        }

      } catch (error) {
        console.error("Prediction error: ", error);
        setPredictionResult({
          classIndex: null,
          label: 'Prediction failed',
          confidences: {},
        });
      } finally {
        audioTensor.dispose(); // Dispose the tensor after prediction
      }

      audioDataBufferRef.current = []; // Reset buffer after prediction
    } else {
      console.log("Not enough audio data for prediction: ", audioDataBuffer.length);
    }
  };


  useEffect(() => {
    const processAudio = () => {
      if (audioRef.current && audioContextRef.current) {
        const analyser = audioContextRef.current.createAnalyser();
        audioRef.current.connect(analyser);
        analyser.fftSize = 2048;
        const bufferLength = analyser.frequencyBinCount;
        const dataArray = new Float32Array(bufferLength);

        const collectAudioData = () => {
          analyser.getFloatTimeDomainData(dataArray); // Collect time-domain data
          audioDataBufferRef.current.push(...dataArray); // Append to audio buffer

          if (isPredicting) {
            predictLiveAudio(); // Immediately predict on new data
          }
        };

        const intervalId = setInterval(collectAudioData, 100); // Collect every 100ms
        return () => clearInterval(intervalId); // Cleanup interval on component unmount
      }
    };

    processAudio();
  }, [isPredicting]);

  // Toggle prediction state
  const toggleAudioPrediction = () => {
    setIsPredicting((prevState) => !prevState); // Toggle prediction state
  };

  const addClass = () => {
    const newClassId = classes.length + 1; // Increment class ID for the new class
    const colors = ['#e57373', '#81c784', '#64b5f6', '#ffd54f', '#a1887f']; // Array of colors for class visualization
    setClasses([...classes, {
      id: newClassId,
      name: `Class ${newClassId}`, // Assign a default name for the new class
      color: colors[newClassId % colors.length], // Cycle through predefined colors
      audio: [] // Initialize an empty array for audio samples
    }]);
  };

  return (
    <div className="bg-gray-100 min-h-screen">
      <header className="bg-blue-800">
        <nav className="h-16"></nav>
      </header>
      <div className="container mx-auto mt-5">
        {/* Add Class Section */}
        <div className="bg-white p-5 mb-5 shadow">
          <div className="flex items-center space-x-4">
            <h1 className="text-lg font-semibold">{type}</h1>
            <button
              className="bg-blue-600 text-white px-4 py-2 rounded hover:bg-blue-700"
              onClick={addClass}
            >
              Add Class
            </button>
          </div>
        </div>

        {/* Classes Grid */}
        <div className="grid grid-cols-1 sm:grid-cols-2 gap-5">
          {classes.map((cls) => (
            <div key={cls.id} className={`p-5 shadow rounded`} style={{ backgroundColor: cls.color }}>
              <h2 className="text-white text-lg mb-3">{cls.name}</h2>
              <div className="flex items-center space-x-3">
                <label className="bg-white text-sm font-medium text-gray-800 px-4 py-2 rounded cursor-pointer hover:bg-gray-100">
                  <span>Upload Audio</span>
                  <input
                    type="file"
                    accept="audio/*"
                    hidden
                    onChange={(e) => handleAudioUpload(e, cls.id)}
                  />
                </label>
                <button
                  className="bg-white text-sm font-medium text-gray-800 px-4 py-2 rounded hover:bg-gray-100"
                  onClick={() => recordAudio(cls.id)}
                >
                  Record Audio
                </button>
              </div>

              {/* Recorded Audio */}
              <div className="mt-5">
                <h3 className="text-white font-medium">Recorded Audio:</h3>
                <div className="flex flex-wrap space-x-3 mt-3">
                  {cls.audio.map((audioBlob, index) => (
                    <audio key={index} controls className="w-full sm:w-auto">
                      <source src={URL.createObjectURL(audioBlob)} type="audio/wav" />
                    </audio>
                  ))}
                </div>
              </div>
            </div>
          ))}
        </div>

        {/* Training Section */}
        <div className="bg-white p-5 mt-5 shadow text-center">
          <h2 className="text-lg font-semibold">Training</h2>
          <button
            className={`mt-4 px-5 py-2 rounded ${isTraining || isModelTrained
                ? 'bg-gray-400 text-gray-700 cursor-not-allowed'
                : 'bg-red-600 text-white hover:bg-red-700'
              }`}
            onClick={handleTraining}
            disabled={isTraining || isModelTrained}
          >
            {isTraining ? (
              <div className="flex justify-center">
                <div className="animate-spin w-5 h-5 border-2 border-t-transparent border-white rounded-full"></div>
              </div>
            ) : (
              'Train Model'
            )}
          </button>
        </div>

        {/* Testing Section */}
        <div className="bg-white p-5 mt-5 shadow">
          <h2 className="text-lg font-semibold">Test the Model</h2>
          <button
            className="mt-4 px-5 py-2 bg-blue-600 text-white rounded hover:bg-blue-700"
            onClick={toggleAudioPrediction}
          >
            {isPredicting ? 'Stop Prediction' : 'Start Prediction'}
          </button>
        </div>

        {/* Prediction Results */}
        {predictionResult && (
          <div className="bg-white p-5 mt-5 shadow">
            <h2 className="text-lg font-semibold">Prediction Result:</h2>
            <p>Class Index: {predictionResult.classIndex}</p>
            <p>Label: {predictionResult.label}</p>
            <p>Confidence: {JSON.stringify(predictionResult.confidences)}</p>
          </div>
        )}
      </div>
    </div>
  );
};

export default AudioClassifier;
