
import React, { useState, useRef, useEffect, useCallback } from 'react';
import { Container, Grid, Paper,LinearProgress, Button, Typography, IconButton, CircularProgress, AppBar, Toolbar } from '@mui/material';
import CameraAltIcon from '@mui/icons-material/CameraAlt';

import UploadIcon from '@mui/icons-material/Upload';
import * as handpose from '@tensorflow-models/handpose';
import AddIcon from '@mui/icons-material/Add';

import ArrowBackIcon from '@mui/icons-material/ArrowBack';
import Webcam from 'react-webcam';
import * as tf from '@tensorflow/tfjs';
import * as posenet from '@tensorflow-models/posenet';
import * as knnClassifier from '@tensorflow-models/knn-classifier';
import html2canvas from 'html2canvas';
import { useNavigate } from 'react-router-dom';
import * as cocossd from '@tensorflow-models/coco-ssd';
import mlImage from './image.png';

const PoseClassifier = ({ name, description, type }) => {
  // Per-class training data: display name, UI colour and stored image data URLs.
  const [classes, setClasses] = useState([
    { id: 1, name: 'Class 1', color: '#df4ef2', images: [] },
    { id: 2, name: 'Class 2', color: '#89f582', images: [] }
  ]);
  const [isPredicting, setIsPredicting] = useState(false);
  const [isTraining, setIsTraining] = useState(false);
  const [isModelTrained, setIsModelTrained] = useState(false);
  // Lazy initializer: create the KNN classifier once, not on every render.
  const [classifier] = useState(() => knnClassifier.create());
  // Which webcam is open: a class id, 'prediction', or null.
  const [activeWebcam, setActiveWebcam] = useState(null);
  const [net, setNet] = useState(null); // HandPose model; null until loaded
  const [predictionResult, setPredictionResult] = useState('');
  const navigate = useNavigate();
  // NOTE(review): setIsLoading is never called here, so isLoading stays false.
  const [isLoading, setIsLoading] = useState(false); // Loading state for webcam

  // Release the classifier's tensors when the component unmounts.
  useEffect(() => () => classifier.dispose(), [classifier]);

  // Shared by whichever <Webcam>/<canvas> pair is currently mounted.
  const webcamRef = useRef(null);
  const canvasRef = useRef(null);

  // Load the HandPose model once on mount; `net` stays null until it is ready.
  useEffect(() => {
    // Prevents setting state after the component has unmounted.
    let cancelled = false;

    const loadModel = async () => {
      try {
        const loadedModel = await handpose.load();
        if (!cancelled) {
          setNet(loadedModel);
        }
      } catch (err) {
        console.error('Failed to load the HandPose model:', err);
      }
    };

    loadModel();
    return () => {
      cancelled = true;
    };
  }, []);

  // Open the webcam for `classId`, or close it if that class's webcam is already open.
  const toggleWebcamForClass = (classId) => {
    const isAlreadyOpen = activeWebcam === classId;
    setActiveWebcam(isAlreadyOpen ? null : classId);
  };

  // Go up one level in the router hierarchy.
  const handleBackClick = () => {
    const parentRoute = '../';
    navigate(parentRoute);
  };

  // Pairs of HandPose landmark indices joined by a line when drawing the hand
  // skeleton. Index 0 is the wrist/palm base; 1-4 thumb, 5-8 index, 9-12 middle,
  // 13-16 ring, 17-20 pinky. The last row closes the palm outline, including the
  // wrist-to-pinky edge [0, 17] that was previously missing.
  const FINGER_CONNECTIONS = Object.freeze([
    [0, 1], [1, 2], [2, 3], [3, 4], // Thumb
    [0, 5], [5, 6], [6, 7], [7, 8], // Index finger
    [9, 10], [10, 11], [11, 12], // Middle finger
    [13, 14], [14, 15], [15, 16], // Ring finger
    [17, 18], [18, 19], [19, 20], // Pinky finger
    [5, 9], [9, 13], [13, 17], [0, 17] // Palm outline (knuckles and wrist)
  ]);

  // Paint hand keypoints (red dots) and the skeleton (green lines) onto the
  // overlay canvas. Landmark coordinates are in video pixel space and are scaled
  // to the canvas size. The canvas is deliberately NOT cleared here, so anything
  // the caller drew first (the current video frame) stays underneath.
  const drawHandDetections = (predictions, videoWidth, videoHeight) => {
    const canvas = canvasRef.current;
    // Nothing to draw on, or unknown video size (would divide by zero).
    if (!canvas || !videoWidth || !videoHeight) return;

    const ctx = canvas.getContext('2d');
    const scaleX = canvas.width / videoWidth;
    const scaleY = canvas.height / videoHeight;

    predictions.forEach(({ landmarks }) => {
      // Keypoints
      landmarks.forEach(([x, y]) => {
        ctx.beginPath();
        ctx.arc(x * scaleX, y * scaleY, 3, 0, 2 * Math.PI);
        ctx.fillStyle = "red";
        ctx.fill();
      });

      // Skeleton lines between connected keypoints
      FINGER_CONNECTIONS.forEach(([startIdx, endIdx]) => {
        const [x1, y1] = landmarks[startIdx];
        const [x2, y2] = landmarks[endIdx];
        ctx.beginPath();
        ctx.moveTo(x1 * scaleX, y1 * scaleY);
        ctx.lineTo(x2 * scaleX, y2 * scaleY);
        ctx.strokeStyle = "green";
        ctx.lineWidth = 2;
        ctx.stroke();
      });
    });
  };

  // Continuously estimate hand landmarks on the mounted webcam and paint the
  // current video frame plus the landmarks onto the overlay canvas.
  useEffect(() => {
    // Internal canvas resolution; matches the 300x300 CSS box the canvas is shown in.
    const canvasWidth = 300;
    const canvasHeight = 300;
    let frameId = null;
    let cancelled = false;

    const detectHands = async () => {
      const video = webcamRef.current?.video;
      if (net && video && video.readyState === 4) {
        try {
          const { videoWidth, videoHeight } = video;
          // Keep the element's width/height attributes equal to the stream size so
          // the landmark coordinates match the videoWidth/videoHeight scaling below.
          video.width = videoWidth;
          video.height = videoHeight;

          const predictions = await net.estimateHands(video);

          // Re-read the canvas after the await: it may have been unmounted/remounted.
          const canvas = canvasRef.current;
          if (!cancelled && canvas) {
            const ctx = canvas.getContext('2d');
            // Assigning width/height also clears the canvas.
            canvas.width = canvasWidth;
            canvas.height = canvasHeight;
            ctx.drawImage(video, 0, 0, canvasWidth, canvasHeight);
            drawHandDetections(predictions, videoWidth, videoHeight);
          }
        } catch (err) {
          // Keep the loop alive on transient failures (e.g. video resizing).
          console.error('Hand detection error:', err);
        }
      }
      if (!cancelled) {
        frameId = requestAnimationFrame(detectHands);
      }
    };

    detectHands(); // Start hand detection loop

    // Stop the loop on unmount or when `net` changes, so loops never pile up.
    return () => {
      cancelled = true;
      if (frameId !== null) {
        cancelAnimationFrame(frameId);
      }
    };
  }, [net]);

  // Add one stored image (a data URL) to the KNN classifier under label `classId - 1`.
  // Before the first training run this is a no-op: handleTraining bulk-adds every
  // stored image, so adding here as well would duplicate examples. After training,
  // new captures/uploads are added straight away so the live preview picks them up.
  const addImageToClassifier = (imageSrc, classId) => {
    if (!classifier || classifier.getNumClasses() === 0) return;

    const imgElement = new Image();
    imgElement.onload = () => {
      // tidy frees the intermediate tensors; addExample keeps its own copy.
      tf.tidy(() => {
        const imgTensor = tf.browser.fromPixels(imgElement)
          .resizeBilinear([224, 224])
          .toFloat()
          .expandDims(0);
        classifier.addExample(imgTensor, classId - 1);
      });
    };
    imgElement.onerror = () => {
      console.error('Error loading image for classification');
    };
    // Assign src last so the handlers are attached before loading starts.
    imgElement.src = imageSrc;
  };

  // Whether the mounted webcam's <video> element has enough data to read
  // (readyState 4). Returns the falsy ref/video value itself when either is missing.
  const isWebcamReady = () => {
    const webcam = webcamRef.current;
    if (!webcam) return webcam;
    const { video } = webcam;
    if (!video) return video;
    return video.readyState === 4;
  };

  // Capture the current webcam frame with detected hand keypoints drawn on top.
  // The result is stored as a JPEG data URL in class `classId` (only the latest
  // 20 are kept) and passed to addImageToClassifier.
  const captureImage = useCallback(async (classId) => {
    if (!webcamRef.current || !net) return;
    if (!isWebcamReady()) {
      console.error("Webcam is not ready.");
      return;
    }

    const video = webcamRef.current.video;
    // Dimensions of the captured image
    const desiredWidth = 320;
    const desiredHeight = 240;

    try {
      // Compose on a detached canvas. The visible overlay canvas is resized and
      // cleared on every animation frame by the detection loop, which would race
      // with the await below and corrupt the snapshot.
      const snapshot = document.createElement('canvas');
      snapshot.width = desiredWidth;
      snapshot.height = desiredHeight;
      const ctx = snapshot.getContext('2d');

      // Draw the current video frame first, then detect hands on the live video.
      ctx.drawImage(video, 0, 0, desiredWidth, desiredHeight);
      const predictions = await net.estimateHands(video);

      // Landmarks are in video pixel space; scale them into snapshot space.
      const scaleX = desiredWidth / video.videoWidth;
      const scaleY = desiredHeight / video.videoHeight;
      ctx.fillStyle = 'red';
      predictions.forEach(({ landmarks }) => {
        landmarks.forEach(([x, y]) => {
          ctx.beginPath();
          ctx.arc(x * scaleX, y * scaleY, 5, 0, 2 * Math.PI);
          ctx.fill();
        });
      });

      // JPEG at 50% quality keeps the stored data URLs small.
      const imageSrc = snapshot.toDataURL('image/jpeg', 0.5);
      if (!imageSrc) {
        console.error('Canvas did not return any image');
        return;
      }

      setClasses(prevClasses =>
        prevClasses.map(cls =>
          cls.id === classId
            ? { ...cls, images: [...cls.images, imageSrc].slice(-20) } // keep latest 20
            : cls
        )
      );
      addImageToClassifier(imageSrc, classId);
    } catch (err) {
      console.error('Error capturing image:', err);
    }
    // Refs and state setters are stable. NOTE(review): addImageToClassifier and
    // isWebcamReady are captured from the render that created this callback.
  }, [webcamRef, net]);

  // Read each selected file as a data URL, append it to class `classId` (keeping
  // only the latest 20 images) and pass it to addImageToClassifier.
  const handleImageUpload = (event, classId) => {
    const input = event.target;
    // Copy the selection before clearing the input below.
    const files = Array.from(input.files ?? []);
    // Reset so that choosing the same file again still fires `change`.
    input.value = '';

    files.forEach((file) => {
      const reader = new FileReader();
      reader.onload = () => {
        const imageSrc = reader.result;
        setClasses((prevClasses) =>
          prevClasses.map((cls) =>
            cls.id === classId
              ? { ...cls, images: [...cls.images, imageSrc].slice(-20) } // keep latest 20
              : cls
          )
        );
        addImageToClassifier(imageSrc, classId);
      };
      reader.onerror = () => {
        console.error('Error reading uploaded image:', reader.error);
      };
      reader.readAsDataURL(file); // Convert image file to Base64
    });
  };
  // (Re)build the KNN classifier from every stored image. Images load in parallel
  // and the classifier is marked trained only after all of them have been processed.
  const handleTraining = async () => {
    setIsTraining(true);
    try {
      // Start from scratch so training never duplicates examples.
      classifier.clearAllClasses();

      const loadImage = (src) =>
        new Promise((resolve, reject) => {
          const img = new Image();
          img.onload = () => resolve(img);
          img.onerror = () => reject(new Error('Failed to load training image'));
          img.src = src;
        });

      const jobs = classes.flatMap((cls) =>
        cls.images.map(async (imageSrc) => {
          try {
            const imgElement = await loadImage(imageSrc);
            // tidy frees the intermediate tensors; addExample keeps its own copy.
            tf.tidy(() => {
              const imgTensor = tf.browser.fromPixels(imgElement)
                .resizeBilinear([224, 224])
                .toFloat()
                .expandDims(0);
              classifier.addExample(imgTensor, cls.id - 1);
            });
          } catch (err) {
            // Skip a bad image rather than aborting the whole training run.
            console.error('Error loading image for classification', err);
          }
        })
      );
      await Promise.all(jobs);
    } finally {
      setIsTraining(false);
    }

    if (classifier.getNumClasses() === 0) {
      alert('Add at least one image before training.');
      return;
    }
    setIsModelTrained(true);
    alert('Training Complete!');
  };
  // Classify the current contents of the overlay canvas once and store the result
  // ({ classIndex, label: class name, confidences keyed by label }). Scheduling is
  // left to the caller; this function does not reschedule itself.
  const predictFromCanvas = async () => {
    const canvas = canvasRef.current;
    if (!canvas || !isPredicting || !(classifier?.getNumClasses() > 0)) {
      return;
    }

    let imgTensor;
    try {
      // Only the synchronous tensor construction goes in tidy: predictClass is
      // async, and tidy would dispose tensors it still needs after the first await.
      imgTensor = tf.tidy(() => {
        const ctx = canvas.getContext('2d');
        const imgData = ctx.getImageData(0, 0, canvas.width, canvas.height);
        return tf.browser.fromPixels(imgData)
          .resizeBilinear([224, 224]) // Resize to 224x224
          .toFloat()
          .expandDims(0); // Add batch dimension
      });

      const prediction = await classifier.predictClass(imgTensor);

      // Examples are added with label `cls.id - 1`; predictClass reports that label
      // as a string. classIndex is only the order in which labels were first added.
      const predictedClass = classes.find(
        (cls) => String(cls.id - 1) === String(prediction.label)
      );

      setPredictionResult({
        classIndex: prediction.classIndex,
        label: predictedClass ? predictedClass.name : 'Unknown Class',
        confidences: prediction.confidences,
      });
    } catch (error) {
      console.error("Prediction error: ", error);
      setPredictionResult({
        classIndex: null,
        label: 'Prediction failed',
        confidences: {},
      });
    } finally {
      imgTensor?.dispose(); // Free the input tensor
    }
  };
  

  // Flip live-preview mode. The prediction webcam becomes the active one when the
  // preview starts and is released when it stops. Any previous result is cleared.
  const toggleWebcamForPrediction = () => {
    const nextActiveWebcam = isPredicting ? null : 'prediction';
    setIsPredicting((wasPredicting) => !wasPredicting);
    setActiveWebcam(nextActiveWebcam);
    setPredictionResult('');
  };
  // Append a new, empty class. Its name follows the 'Class N' pattern of the
  // initial classes. A functional update is used so rapid clicks cannot reuse an id.
  const addClass = () => {
    const colors = ['#34ebba','#83f29d','#b583f2','#f2839d','#d7fc5b'];
    setClasses((prevClasses) => {
      const newClassId = prevClasses.reduce((maxId, cls) => Math.max(maxId, cls.id), 0) + 1;
      return [
        ...prevClasses,
        {
          id: newClassId,
          name: `Class ${newClassId}`,
          color: colors[newClassId % colors.length],
          images: []
        }
      ];
    });
  };

  // While the preview is on, classify the overlay canvas once per second. A tick
  // is skipped while the previous prediction is still running, so runs never
  // overlap. NOTE(review): this uses the predictFromCanvas closure from the render
  // that started the preview.
  useEffect(() => {
    if (!isPredicting) return undefined;

    let inFlight = false;
    const interval = setInterval(async () => {
      if (inFlight) return;
      inFlight = true;
      try {
        await predictFromCanvas();
      } catch (err) {
        console.error('Prediction loop error:', err);
      } finally {
        inFlight = false;
      }
    }, 1000);

    return () => clearInterval(interval);
  }, [isPredicting, classifier]);

  return (  <div className="min-h-screen bg-gray-50"
    style={{
      backgroundImage: `url(${mlImage})`,
      backgroundSize: 'cover',
      backgroundPosition: 'center',
      backgroundRepeat: 'no-repeat',
    }}
  >
    {/* <video
        autoPlay
        loop
        muted
        className="absolute inset-0 w-full h-full object-cover z-0"
      >
        <source src={require('./Video/3130284-uhd_3840_2160_30fps.mp4')} type="video/mp4" />
        Your browser does not support the video tag.
      </video> */}
    {/* Header */}
    {/* <div className="bg-blue-600 text-white">
      <div className="max-w-7xl mx-auto px-4">
        <div className="flex items-center h-16">
          <IconButton
            onClick={handleBackClick}
            className="text-white hover:bg-blue-700"
          >
            <ArrowBackIcon />
          </IconButton>
          <Typography variant="h6" className="ml-4">
            {name}  Hand Pose Detection Model
          </Typography>
        </div>
      </div>
    </div> */}

    <Container className="py-8">
      <div className="grid grid-cols-1 lg:grid-cols-3 gap-8">
        {/* Training Section */}
        <div className="lg:col-span-2">
          <Paper className="p-6 mb-6" style={{ backgroundColor: 'transparent', boxShadow: 'none' }}>
            <div className="flex justify-between items-center mb-6">
              <Typography variant="h5" className="font-medium">
                Training
              </Typography>
              <Button
                onClick={addClass}
                variant="outlined"
                startIcon={<AddIcon />}
                style={{
                  backgroundColor: '#007BFF', // Background color of the button
                  color: 'white',              // Text color
                  borderRadius: '25px',        // Set border radius for round shape
                  padding: '10px 20px',        // Adjust padding for better sizing
                }}
              >
                Add New Class
              </Button>


            </div>

            <div className="flex flex-col items-center space-y-6">
              {classes.map((cls) => (
                <Paper key={cls.id} elevation={1} className="overflow-hidden" style={{ width: '400px', backgroundColor: 'transparent', boxShadow: 'none' }}>

                  <div
                    className="p-4"
                    style={{ backgroundColor: cls.color }}
                  >
                    <div className="flex justify-between items-center">
                      <Typography className="text-white font-medium">
                        {cls.name}
                      </Typography>
                      <div className="flex gap-2">
                        <Button
                          variant="contained"
                          component="label"
                          className="bg-white hover:bg-gray-100"
                          style={{
                            color: cls.color, borderRadius: '25px',        // Set border radius for round shape
                            padding: '10px 20px',
                          }}
                          startIcon={<UploadIcon />}
                        >
                          Upload
                          <input
                            type="file"
                            hidden
                            accept="image/*"
                            onChange={(e) => handleImageUpload(e, cls.id)}
                          />
                        </Button>
                        <Button
                          variant="contained"
                          className="bg-white hover:bg-gray-100"
                          style={{
                            color: cls.color,
                            borderRadius: '25px',  // Set border radius for round shape
                            padding: '10px 20px',
                          }}
                          startIcon={<CameraAltIcon />}
                          onClick={() => toggleWebcamForClass(cls.id)}
                        >
                          {activeWebcam === cls.id ? 'Stop' : 'Start'}
                        </Button>
                      </div>
                    </div>

                    {/* Webcam Section */}
                    {activeWebcam === cls.id && (
                      <div className="mt-4 flex flex-col items-center">
                        <div className="relative" style={{ width: '300px', height: '300px' }}>
                          {isLoading ? (
                            // Loading Design (spinner) before the webcam opens
                            <div
                              className="flex flex-col items-center justify-center"
                              style={{ width: '300px', height: '300px' }}
                            >
                              <CircularProgress /> {/* Spinner as loading indicator */}
                              <Typography variant="body2" className="mt-2">
                                Opening Camera...
                              </Typography>
                            </div>
                          ) : (
                            <>
                              <Webcam
                                audio={false}
                                ref={webcamRef}
                                screenshotFormat="image/jpeg"
                                style={{
                                  width: '300px',
                                  height: '300px',
                                  position: 'absolute',
                                  top: 0,
                                  left: 0,
                                }}
                                videoConstraints={{
                                  width: 300,
                                  height: 300,
                                  facingMode: 'user',
                                }}
                              />
                              <canvas
                                ref={canvasRef}
                                style={{
                                  width: '300px',
                                  height: '300px',
                                  position: 'absolute',
                                  top: 0,
                                  left: 0,
                                }}
                              />
                            </>
                          )}
                        </div>

                        <Button
                          variant="contained"
                          className="mt-4 bg-white hover:bg-gray-100"
                          style={{
                            color: cls.color,
                            borderRadius: '25px',  // Set border radius for round shape
                            padding: '10px 20px',
                          }}
                          onClick={() => captureImage(cls.id)}
                        >
                          Capture
                        </Button>
                      </div>
                    )}


                    {/* Image Gallery */}
                    {cls.images.length > 0 && (
                      <div className="mt-4 flex flex-wrap gap-2 justify-center">
                        {cls.images.map((image, index) => (
                          <img
                            key={index}
                            src={image}
                            alt={`Sample ${index + 1}`}
                            style={{
                              width: '50px',
                              height: '50px',
                              objectFit: 'cover'
                            }}
                            className="rounded"
                          />
                        ))}
                      </div>
                    )}
                  </div>
                </Paper>
              ))}
            </div>
          </Paper>
        </div>

        {/* Preview Section */}
        <div className="lg:col-span-1">
          <Paper className="p-6 sticky top-6">
            <Typography variant="h5" className="font-medium mb-6">
              Preview
            </Typography>

            <div className="space-y-4">
              <Button
                variant="contained"
                color="primary"
                fullWidth
                style={{
                  borderRadius: '25px',        // this line is for making the shape of the button round shape 
                  padding: '10px 20px',
                }}
                onClick={handleTraining}
                disabled={isTraining || isModelTrained}
                className="h-12"
              >
                {isTraining ? (
                  <CircularProgress size={24} color="inherit" />
                ) : (
                  'Train Model'
                )}
              </Button>

              <Button
                variant="outlined"
                color="primary"
                fullWidth
                onClick={toggleWebcamForPrediction}
                className="h-12"
              >
                {isPredicting ? 'Stop Preview' : 'Start Preview'}
              </Button>

              {isPredicting && (
                <div className="mt-6">
                  <div className="flex justify-center">
                    <div className="relative" style={{ width: '300px', height: '300px' }}>
                      <Webcam
                        audio={false}
                        ref={webcamRef}
                        screenshotFormat="image/jpeg"
                        style={{
                          width: '300px',
                          height: '300px',
                          position: 'absolute',
                          top: 0,
                          left: 0,
                        }}
                        videoConstraints={{
                          width: 300,
                          height: 300,
                          facingMode: 'user',
                        }}
                      />
                      <canvas
                        ref={canvasRef}
                        style={{
                          width: '300px',
                          height: '300px',
                          position: 'absolute',
                          top: 0,
                          left: 0,
                        }}
                      />
                    </div>
                  </div>


                  <Paper className="mt-4 p-4 bg-gray-50">
                    <Typography variant="subtitle1" className="font-medium">
                      Prediction Results
                    </Typography>
                    <div className="mt-2 space-y-4">
                      {predictionResult.confidences &&
                        Object.entries(predictionResult.confidences).map(([key, value]) => (
                          <div key={key} className="space-y-1">
                            <Typography variant="body2" className="text-gray-600">
                              Class: {key}
                            </Typography>
                            <LinearProgress
                              variant="determinate"
                              value={value * 100}
                              style={{ height: 10, borderRadius: 5 }}
                            />
                            <Typography variant="body2" className="text-gray-600">
                              Confidence: {(value * 100).toFixed(1)}%
                            </Typography>
                          </div>
                        ))
                      }
                    </div>
                  </Paper>
                </div>
              )}
            </div>
          </Paper>
        </div>
      </div>
    </Container>
  </div>

  );
};

export default PoseClassifier;



