import React, { useEffect, useState, useRef } from 'react';
import { Container, Grid, Paper, Button, Typography, IconButton, CircularProgress, AppBar, Toolbar } from '@mui/material';

import UploadIcon from '@mui/icons-material/Upload';

import ArrowBackIcon from '@mui/icons-material/ArrowBack';
import * as tf from '@tensorflow/tfjs';
import * as knnClassifier from '@tensorflow-models/knn-classifier'; // Updated import
import { useNavigate } from 'react-router-dom';                       

// Number of samples every training clip is trimmed or zero-padded to before it
// is given to the KNN classifier. This is a sample count, not a duration: clips
// are decoded at the AudioContext's sample rate, so the length in seconds
// depends on that rate (e.g. 16000 samples ≈ 0.36 s at 44.1 kHz).
const TARGET_LENGTH = 16000;
// Size of the live-audio window classified during prediction. The KNN
// classifier requires queries to have the same shape as its training examples,
// so this is tied to TARGET_LENGTH rather than declared independently.
const REQUIRED_SIZE = TARGET_LENGTH;

const AudioClassifier = ({ name, description, type }) => {
  const [classes, setClasses] = useState([
    { id: 1, name: 'Class 1', color: '#e57373', audio: [] },
    { id: 2, name: 'Class 2', color: '#81c784', audio: [] }
  ]);
  const [isPredicting, setIsPredicting] = useState(false);
  const [isTraining, setIsTraining] = useState(false);
  const [isModelTrained, setIsModelTrained] = useState(false);
  const [classifier, setClassifier] = useState(knnClassifier.create());
  const [predictionResult, setPredictionResult] = useState(null);
  const audioRef = useRef(null);
  const audioContextRef = useRef(null);
  const microphoneRef = useRef(null);
  const audioDataBufferRef = useRef([]); // Store audio data for prediction
  const navigate = useNavigate();

  // Initialize audio context and microphone
  useEffect(() => {
    const initAudio = async () => {
      audioContextRef.current = new (window.AudioContext || window.webkitAudioContext)();
      microphoneRef.current = await navigator.mediaDevices.getUserMedia({ audio: true });
      audioRef.current = audioContextRef.current.createMediaStreamSource(microphoneRef.current);
      audioContextRef.current.resume(); // Resume the audio context after user gesture
    };

    initAudio();
  }, []);

  const handleBackClick = () => {
    navigate('../'); // Navigate back to the previous page
  };

  // Preprocess audio buffer
  const preprocessAudioBuffer = (audioBuffer) => {
    const channelData = audioBuffer.getChannelData(0);
    if (channelData.length > TARGET_LENGTH) {
      return channelData.slice(0, TARGET_LENGTH);
    } else {
      const paddedAudio = new Float32Array(TARGET_LENGTH);
      paddedAudio.set(channelData);
      return paddedAudio;
    }
  };

  // Handle audio upload
  const handleAudioUpload = (e, classId) => {
    const file = e.target.files[0];
    if (file) {
      const audioBlob = new Blob([file], { type: 'audio/wav' });
      setClasses(prevClasses =>
        prevClasses.map(cls =>
          cls.id === classId
            ? { ...cls, audio: [...cls.audio, audioBlob] }
            : cls
        )
      );
    }
  };

  // Add audio to classifier
  const addAudioToClassifier = (audioBlob, classId) => {
    const reader = new FileReader();
    reader.onloadend = async () => {
      const audioBuffer = await audioContextRef.current.decodeAudioData(reader.result);
      const processedAudio = preprocessAudioBuffer(audioBuffer);
      const audioTensor = tf.tensor(processedAudio).expandDims(0); // Shape: [1, TARGET_LENGTH]
      classifier.addExample(audioTensor, classId - 1);  // KNN class starts from 0
      setClasses(prevClasses =>
        prevClasses.map(cls =>
          cls.id === classId ? { ...cls, audio: [...cls.audio, audioBlob] } : cls
        )
      );
    };
    reader.readAsArrayBuffer(audioBlob);
  };

  // Record audio for training
  const recordAudio = async (classId) => {
    const mediaRecorder = new MediaRecorder(microphoneRef.current);
    const audioChunks = [];

    mediaRecorder.ondataavailable = event => {
      audioChunks.push(event.data);
    };

    mediaRecorder.onstop = () => {
      const audioBlob = new Blob(audioChunks, { type: 'audio/wav' });
      addAudioToClassifier(audioBlob, classId);
      alert(`Audio recorded for ${classes.find(cls => cls.id === classId).name}`);
    };

    mediaRecorder.start();
    setTimeout(() => {
      mediaRecorder.stop();
    }, 3000); // Adjust recording duration as needed
  };

  // Train the model
  const handleTraining = async () => {
    setIsTraining(true);
    await new Promise(resolve => setTimeout(resolve, 3000));
    setIsTraining(false);
    setIsModelTrained(true);
    alert('Training Complete!');
  };

  // Predict live audio
  const predictLiveAudio = async () => {
    const audioDataBuffer = audioDataBufferRef.current;
  
    if (classifier && audioDataBuffer.length >= REQUIRED_SIZE) {
      const trimmedData = audioDataBuffer.slice(-REQUIRED_SIZE); // Use the last REQUIRED_SIZE elements
      const audioTensor = tf.tensor(trimmedData).reshape([1, REQUIRED_SIZE]); // Reshape for model input
  
      try {
        const prediction = await classifier.predictClass(audioTensor);
        
        // Dynamically handle number of classes by using the size of the confidences array
        const confidencesArray = Object.values(prediction.confidences);
        
        // Set a confidence threshold (if needed)
        const confidenceThreshold = 0.5;
  
        // Filter and find the class with the highest confidence
        const validConfidences = confidencesArray.map((conf, index) => {
          return conf >= confidenceThreshold ? { index, conf } : { index, conf: 0 };
        });
  
        // Find the class index with the highest confidence
        const predictedClass = validConfidences.reduce((prev, curr) => {
          return (curr.conf > prev.conf) ? curr : prev;
        });
  
        if (predictedClass && predictedClass.conf > 0) {
          const predictedClassIndex = predictedClass.index;
          
          // Dynamically set the class label based on classIndex + 1 (or however your labels are structured)
          setPredictionResult({
            classIndex: predictedClassIndex + 1,  // Adjust the class index (if labels start from 1)
            label: `Class ${predictedClassIndex + 1}`,  // Adjust the label
            confidences: prediction.confidences,  // Store the confidence values
          });
        } else {
          setPredictionResult({
            classIndex: null,
            label: 'Prediction uncertain',
            confidences: prediction.confidences,
          });
        }
  
      } catch (error) {
        console.error("Prediction error: ", error);
        setPredictionResult({
          classIndex: null,
          label: 'Prediction failed',
          confidences: {},
        });
      } finally {
        audioTensor.dispose(); // Dispose the tensor after prediction
      }
  
      audioDataBufferRef.current = []; // Reset buffer after prediction
    } else {
      console.log("Not enough audio data for prediction: ", audioDataBuffer.length);
    }
  };
  

  useEffect(() => {
    const processAudio = () => {
      if (audioRef.current && audioContextRef.current) {
        const analyser = audioContextRef.current.createAnalyser();
        audioRef.current.connect(analyser);
        analyser.fftSize = 2048;
        const bufferLength = analyser.frequencyBinCount;
        const dataArray = new Float32Array(bufferLength);
  
        const collectAudioData = () => {
          analyser.getFloatTimeDomainData(dataArray); // Collect time-domain data
          audioDataBufferRef.current.push(...dataArray); // Append to audio buffer
  
          if (isPredicting) {
            predictLiveAudio(); // Immediately predict on new data
          }
        };
  
        const intervalId = setInterval(collectAudioData, 100); // Collect every 100ms
        return () => clearInterval(intervalId); // Cleanup interval on component unmount
      }
    };
  
    processAudio();
  }, [isPredicting]);
  
  // Toggle prediction state
  const toggleAudioPrediction = () => {
    setIsPredicting((prevState) => !prevState); // Toggle prediction state
  };
  
  const addClass = () => {
    const newClassId = classes.length + 1; // Increment class ID for the new class
    const colors = ['#e57373', '#81c784', '#64b5f6', '#ffd54f', '#a1887f']; // Array of colors for class visualization
    setClasses([...classes, {
      id: newClassId,
      name: `Class ${newClassId}`, // Assign a default name for the new class
      color: colors[newClassId % colors.length], // Cycle through predefined colors
      audio: [] // Initialize an empty array for audio samples
    }]);
  };
  
  return (
<div style={{ backgroundColor: '#f0f0f0', minHeight: '100vh' }}>
  <AppBar position="static" style={{ backgroundColor: '#003AD4' }}>
    {/* <Toolbar>
      <Button color="inherit" startIcon={<ArrowBackIcon />} onClick={handleBackClick}></Button>
      <Typography variant="h6" style={{ flexGrow: 1 }}>
        StemVerse Machine Learning Environment - Audio Classifier{name}
      </Typography>
    </Toolbar> */}
  </AppBar>
  <Container maxWidth="lg" style={{ marginTop: '20px' }}>
    <Paper style={{ padding: '20px', marginBottom: '20px' }}>
      <Grid container spacing={2} alignItems="center">
        <Grid item>
          <Typography variant="h6">{type}</Typography>
        </Grid>
        
        <Grid item>
          <Button variant="contained" color="primary" onClick={addClass}>
            Add Class
          </Button>
        </Grid>
      </Grid>
    </Paper>

    <Grid container spacing={3}>
      {classes.map((cls) => (
        <Grid item xs={12} sm={6} key={cls.id}>
          <Paper style={{ padding: '20px', backgroundColor: cls.color }}>
            <Typography variant="h6" style={{ color: 'white', marginBottom: '10px' }}>{cls.name}</Typography>
            <Button
              variant="contained"
              component="label"
              startIcon={<UploadIcon />}
              style={{ backgroundColor: '#ffffff', marginRight: '10px', color: cls.color }}
            >
              Upload Audio
              <input type="file" accept="audio/*" hidden onChange={e => handleAudioUpload(e, cls.id)} />
            </Button>
            <Button
              variant="contained"
              onClick={() => recordAudio(cls.id)}
              style={{ backgroundColor: '#ffffff', color: cls.color }}
            >
              Record Audio
            </Button>

            <div style={{ display: 'flex', flexWrap: 'wrap', marginTop: '20px' }}>
              <h3>Recorded Audio:</h3>
              {cls.audio.map((audioBlob, index) => (
                <audio key={index} controls>
                  <source src={URL.createObjectURL(audioBlob)} type="audio/wav" />
                </audio>
              ))}
            </div>
          </Paper>
        </Grid>
      ))}
    </Grid>

    <Paper style={{ padding: '20px', marginTop: '20px', textAlign: 'center' }}>
      <Typography variant="h6">Training</Typography>
      <Button
        variant="contained"
        color="secondary"
        onClick={handleTraining}
        disabled={isTraining || isModelTrained}
      >
        {isTraining ? <CircularProgress size={24} color="inherit" /> : 'Train Model'}
      </Button>
    </Paper>

    <Paper style={{ padding: '20px', marginTop: '20px' }}>
      <Typography variant="h6">Test the Model</Typography>
      <Button
        variant="contained"
        color="primary"
        onClick={toggleAudioPrediction}
      >
        {isPredicting ? 'Stop Prediction' : 'Start Prediction'}
      </Button>
    </Paper>

    {predictionResult && (
      <Paper style={{ padding: '20px', marginTop: '20px' }}>
        <Typography variant="h6">Prediction Result:</Typography>
        <Typography variant="body1">Class Index: {predictionResult.classIndex}</Typography>
        <Typography variant="body1">Label: {predictionResult.label}</Typography>
        <Typography variant="body1">
          Confidence: {JSON.stringify(predictionResult.confidences)}
        </Typography>
      </Paper>
    )}
  </Container>
</div>

  );
};

export default AudioClassifier;
