#!/usr/bin/env python3
"""STT Streaming Client with Voice Activity Detection (VAD).

Connects to stt.mm.mk via WebSocket, streams microphone audio with Silero VAD,
and prints real-time transcription. Designed for Raspberry Pi / Linux.

Usage:
    pip install websocket-client numpy sounddevice torch
    # Optional: pip install silero-vad  (or it auto-downloads)

    python3 stt_client.py                           # default mic, Polish
    python3 stt_client.py --language en              # English
    python3 stt_client.py --language pl --device 2   # specific mic device
    python3 stt_client.py --stream-id kitchen        # name this stream
    python3 stt_client.py --list-devices             # show available mics
    python3 stt_client.py --no-vad                   # disable VAD, send everything
    python3 stt_client.py --server wss://stt.mm.mk   # custom server

Features:
    - Silero VAD filters silence, only sends speech to server
    - Sends raw PCM int16 mono 16kHz over WebSocket binary frames
    - Receives JSON transcription results in real-time
    - Auto-reconnects on disconnect
    - Low CPU usage on Pi (VAD is lightweight)
"""

import argparse
import json
import sys
import time
import struct
import threading
import signal
import os

import numpy as np

# --- Configuration ---
DEFAULT_SERVER = "wss://stt.mm.mk"
SAMPLE_RATE = 16000
CHANNELS = 1
# Silero VAD consumes fixed-size windows that are counted in SAMPLES, not
# milliseconds: 512 samples at 16 kHz (256 at 8 kHz). Treating 512 as a
# duration produced 8192-sample blocks, and made both the send interval and
# the speech padding collapse to a single chunk.
CHUNK_SAMPLES = 512
# Duration of one capture/VAD chunk, derived from the sample count (32 ms).
CHUNK_DURATION_MS = CHUNK_SAMPLES * 1000 // SAMPLE_RATE
# Send audio to server every N ms (accumulate VAD-approved chunks)
SEND_INTERVAL_MS = 250
SEND_INTERVAL_CHUNKS = max(1, SEND_INTERVAL_MS // CHUNK_DURATION_MS)

# VAD settings
VAD_THRESHOLD = 0.5       # speech probability threshold
SPEECH_PAD_MS = 300       # keep sending this long after speech ends
SPEECH_PAD_CHUNKS = max(1, SPEECH_PAD_MS // CHUNK_DURATION_MS)


def load_vad():
    """Load the Silero VAD model through ``torch.hub``.

    Loading is best-effort: any failure (torch not importable, hub download
    or model construction raising) is reported and turned into ``None`` so
    the caller can keep streaming without VAD.

    Returns:
        The loaded model object, or ``None`` when loading failed.
    """
    try:
        import torch
        # torch.hub.load returns (model, utils); only the model is needed here.
        model, _utils = torch.hub.load(
            repo_or_dir='snakers4/silero-vad',
            model='silero_vad',
            trust_repo=True
        )
        return model
    except Exception as e:
        # Diagnostics go to stderr: stdout carries the transcription text.
        print(f"[WARN] Could not load Silero VAD: {e}", file=sys.stderr)
        print("[WARN] Running without VAD — all audio will be sent to server",
              file=sys.stderr)
        return None


def list_audio_devices():
    """Print every device that has at least one input channel.

    Each line shows the device index, name, input channel count and default
    sample rate; the device matching sounddevice's default input is flagged.
    """
    import sounddevice as sd

    print("\nAvailable audio input devices:")
    print("-" * 60)
    for index, info in enumerate(sd.query_devices()):
        channels = info['max_input_channels']
        if channels <= 0:
            # Output-only device: not usable as a microphone.
            continue
        # sd.default.device is indexed here with 0 for the input side.
        suffix = " <-- default" if index == sd.default.device[0] else ""
        name = info['name']
        rate = info['default_samplerate']
        print(f"  [{index}] {name} (in: {channels}ch, "
              f"rate: {rate:.0f}Hz){suffix}")
    print()


class STTClient:
    """Stream microphone audio to the STT WebSocket server and print results.

    ``run()`` owns the connection loop (with automatic reconnects). Once a
    connection opens, a single background thread captures audio through
    sounddevice; its callback optionally gates chunks with the VAD model and
    sends raw int16 PCM as binary frames. Setting ``running`` to ``False``
    (from any thread or a signal handler) stops the capture thread, which
    then closes the socket so that ``run()`` can return.
    """

    # Maximum amount of audio (in seconds) kept in ``send_buffer`` while the
    # socket is down; the oldest audio is discarded first.
    MAX_BACKLOG_SECONDS = 5

    def __init__(self, server_url, language="pl", stream_id="default",
                 device=None, use_vad=True, timestamps=False, diarize=False,
                 itn=False, detect_emotion=False, verbose=False):
        """Store the configuration; no I/O happens until ``run()``.

        Args:
            server_url: Base WebSocket URL of the server.
            language: Language code passed as a query parameter.
            stream_id: Name of this stream, sent to the server and used as
                the prefix of printed lines.
            device: Input device passed to sounddevice (``None`` = default).
            use_vad: Load and apply the VAD model when True.
            timestamps, diarize, itn, detect_emotion: Optional server
                features, each enabled through a ``=1`` query parameter.
            verbose: Print per-word confidences and send statistics.
        """
        self.server_url = server_url
        self.language = language
        self.stream_id = stream_id
        self.device = device
        self.use_vad = use_vad
        self.timestamps = timestamps
        self.diarize = diarize
        self.itn = itn
        self.detect_emotion = detect_emotion
        self.verbose = verbose

        self.running = False
        self.ws = None
        self.vad_model = None
        self.speech_active = False
        self.silence_chunks = 0
        self.send_buffer = bytearray()
        self.chunk_count = 0
        # Capture thread; it survives reconnects, so it is tracked to avoid
        # opening the microphone twice.
        self.audio_thread = None

    def build_url(self):
        """Return the full WebSocket URL including the query string.

        Values are URL-encoded so that stream ids or language codes holding
        reserved characters (``&``, spaces, ``#`` ...) cannot corrupt the
        query. A trailing slash on ``server_url`` is tolerated.
        """
        from urllib.parse import urlencode

        params = [
            ("language", self.language),
            ("rate", SAMPLE_RATE),
            ("stream_id", self.stream_id),
        ]
        optional_flags = (
            ("timestamps", self.timestamps),
            ("diarize", self.diarize),
            ("itn", self.itn),
            ("detect_emotion", self.detect_emotion),
        )
        params.extend((name, 1) for name, enabled in optional_flags if enabled)
        base = self.server_url.rstrip("/")
        return f"{base}/ws/transcribe?{urlencode(params)}"

    @staticmethod
    def _format_confidence(confidence):
        """Return ``" [NN%]"`` for a non-zero numeric confidence, else ``""``."""
        if isinstance(confidence, bool) or not isinstance(confidence, (int, float)):
            return ""
        return f" [{confidence*100:.0f}%]" if confidence else ""

    def on_message(self, ws, message):
        """Handle one server message and print the transcription it carries.

        Frames that are not JSON objects are ignored. ``partial`` results
        overwrite the current terminal line, ``final`` results end it.
        Missing or malformed optional fields are skipped instead of raising.
        """
        try:
            msg = json.loads(message)
        except (TypeError, ValueError):
            # Not JSON (or not decodable text): nothing to display.
            return
        if not isinstance(msg, dict):
            return

        if msg.get("error"):
            print(f"\r[ERROR] {msg['error']}", file=sys.stderr)
            return

        text = msg.get("text", "")
        msg_type = msg.get("type", "")
        conf_str = self._format_confidence(msg.get("confidence", 0))

        if msg_type == "partial":
            # Overwrite current line with partial text
            sys.stdout.write(f"\r\033[K[{self.stream_id}] {text}{conf_str}")
            sys.stdout.flush()
        elif msg_type == "final":
            # Print final on new line
            extras = ""
            emotion = msg.get("emotion")
            if emotion:
                extras += f" [emo:{emotion}]"
            speakers = [s for s in (msg.get("speakers") or []) if isinstance(s, dict)]
            if speakers:
                spk_str = ", ".join(
                    f"{s.get('speaker', '?')}: {s.get('text', '')}" for s in speakers
                )
                extras += f" [speakers: {spk_str}]"
            print(f"\r\033[K[{self.stream_id}] {text}{conf_str}{extras}")

        if self.verbose:
            words = [w for w in (msg.get("words") or []) if isinstance(w, dict)]
            if words:
                word_str = " | ".join(
                    f"{w.get('word', '?')}"
                    f"({self._format_confidence(w.get('confidence', 0)).strip(' []') or '?'})"
                    for w in words
                )
                print(f"  words: {word_str}", file=sys.stderr)

    def on_error(self, ws, error):
        """Report a WebSocket error on stderr."""
        print(f"\r[WS ERROR] {error}", file=sys.stderr)

    def on_close(self, ws, close_status_code, close_msg):
        """Report that the connection was closed."""
        print(f"\r[WS] Disconnected (code={close_status_code})", file=sys.stderr)

    def on_open(self, ws):
        """Start microphone capture once the first connection is open.

        The capture thread keeps running across reconnects (it always sends
        through the current ``self.ws``), so it is only started when no live
        thread exists; starting one per connection would open the input
        device several times.
        """
        print(f"[WS] Connected to {self.server_url} (stream: {self.stream_id})",
              file=sys.stderr)
        thread = self.audio_thread
        if thread is None or not thread.is_alive():
            self.audio_thread = threading.Thread(target=self.audio_loop, daemon=True)
            self.audio_thread.start()

    def _speech_probability(self, audio_f32):
        """Return the VAD speech probability of one chunk, or None without VAD.

        If inference raises, VAD is switched off for the rest of the session
        (all audio is then sent), mirroring the fallback used when the model
        cannot be loaded.
        """
        if not self.use_vad or self.vad_model is None:
            return None
        try:
            import torch
            audio_tensor = torch.from_numpy(audio_f32).float()
            return self.vad_model(audio_tensor, SAMPLE_RATE).item()
        except Exception as e:
            print(f"[VAD ERROR] {e}; disabling VAD, all audio will be sent",
                  file=sys.stderr)
            self.vad_model = None
            return None

    def audio_callback(self, indata, frames, time_info, status):
        """Called by sounddevice for each audio chunk."""
        if status:
            print(f"[AUDIO] {status}", file=sys.stderr)

        audio_f32 = indata[:, 0]  # mono
        # Clip before scaling so out-of-range samples saturate instead of
        # wrapping around when cast to int16.
        pcm_bytes = (np.clip(audio_f32, -1.0, 1.0) * 32767).astype(np.int16).tobytes()

        speech_prob = self._speech_probability(audio_f32)
        if speech_prob is None:
            # No VAD — send everything
            self.send_buffer.extend(pcm_bytes)
        elif speech_prob >= VAD_THRESHOLD:
            self.speech_active = True
            self.silence_chunks = 0
            self.send_buffer.extend(pcm_bytes)
        elif self.speech_active:
            # Pad: keep sending for a bit after speech ends
            self.silence_chunks += 1
            self.send_buffer.extend(pcm_bytes)
            if self.silence_chunks >= SPEECH_PAD_CHUNKS:
                self.speech_active = False
                self.silence_chunks = 0
                # Send remaining buffer
                self._flush_buffer()
        # Otherwise: silence outside of speech/padding is dropped.

        self.chunk_count += 1
        if self.chunk_count >= SEND_INTERVAL_CHUNKS:
            self._flush_buffer()
            self.chunk_count = 0

    def _is_connected(self):
        """Return True when the current WebSocket has a connected socket."""
        ws = self.ws
        sock = ws.sock if ws is not None else None
        return bool(sock and sock.connected)

    def _flush_buffer(self):
        """Send accumulated audio buffer to server.

        While disconnected the audio is kept for the next connection, but
        only the most recent ``MAX_BACKLOG_SECONDS`` so memory stays bounded.
        A failed send drops the payload and reports the error.
        """
        if not self.send_buffer:
            return

        if not self._is_connected():
            limit = SAMPLE_RATE * 2 * self.MAX_BACKLOG_SECONDS  # int16 = 2 bytes
            overflow = len(self.send_buffer) - limit
            if overflow > 0:
                # Both values are even, so trimming keeps sample alignment.
                del self.send_buffer[:overflow]
            return

        payload = bytes(self.send_buffer)
        self.send_buffer.clear()
        try:
            self.ws.send(payload, opcode=0x2)  # binary
        except Exception as e:
            print(f"\r[WS] Send failed, dropped {len(payload)} bytes: {e}",
                  file=sys.stderr)
            return
        if self.verbose:
            dur_ms = len(payload) / (SAMPLE_RATE * 2) * 1000
            print(f"  [SEND] {len(payload)} bytes ({dur_ms:.0f}ms)",
                  file=sys.stderr)

    def _close_connection(self):
        """Send the end signal (empty binary frame) and close the socket.

        Best-effort and safe to call repeatedly. Closing the socket is what
        makes the blocking ``run_forever()`` call in ``run()`` return.
        """
        ws = self.ws
        if ws is None:
            return
        try:
            if ws.sock and ws.sock.connected:
                ws.send(b"", opcode=0x2)
            ws.close()
        except Exception as e:
            if self.verbose:
                print(f"[WS] Error while closing: {e}", file=sys.stderr)

    def audio_loop(self):
        """Capture audio from microphone using sounddevice.

        Runs until ``self.running`` becomes False or the input stream fails
        (which also stops the client). On exit the remaining audio is
        flushed and the connection is closed so that ``run()`` unblocks.
        """
        try:
            import sounddevice as sd

            with sd.InputStream(
                samplerate=SAMPLE_RATE,
                channels=CHANNELS,
                dtype='float32',
                blocksize=CHUNK_SAMPLES,
                device=self.device,
                callback=self.audio_callback,
            ):
                print(f"[MIC] Recording (rate={SAMPLE_RATE}, chunk={CHUNK_DURATION_MS}ms, "
                      f"VAD={'on' if self.use_vad and self.vad_model else 'off'})",
                      file=sys.stderr)
                while self.running:
                    time.sleep(0.1)
        except Exception as e:
            print(f"[MIC ERROR] {e}", file=sys.stderr)
            self.running = False
        finally:
            # Reached only when the client is stopping: the stream is already
            # closed here, so no callback can race with this final flush.
            self._flush_buffer()
            self._close_connection()

    def _wait_before_reconnect(self, seconds=3.0):
        """Sleep up to ``seconds``, returning early once the client stops."""
        deadline = time.monotonic() + seconds
        while self.running and time.monotonic() < deadline:
            time.sleep(0.1)

    def run(self):
        """Connect and stream. Auto-reconnects on failure."""
        import websocket

        if self.use_vad:
            print("[VAD] Loading Silero VAD model...", file=sys.stderr)
            self.vad_model = load_vad()

        self.running = True
        url = self.build_url()

        while self.running:
            try:
                print(f"[WS] Connecting to {url}...", file=sys.stderr)
                self.ws = websocket.WebSocketApp(
                    url,
                    on_open=self.on_open,
                    on_message=self.on_message,
                    on_error=self.on_error,
                    on_close=self.on_close,
                )
                self.ws.run_forever(ping_interval=30, ping_timeout=10)
            except KeyboardInterrupt:
                self.running = False
                break
            except Exception as e:
                print(f"[WS] Connection failed: {e}", file=sys.stderr)

            if self.running:
                print("[WS] Reconnecting in 3s...", file=sys.stderr)
                self._wait_before_reconnect(3.0)

        # Give the capture thread a moment to flush and send the end signal,
        # then make sure the socket is closed in any case.
        thread = self.audio_thread
        if thread is not None and thread.is_alive():
            thread.join(timeout=2.0)
        self._close_connection()
        print("\n[DONE] Client stopped.", file=sys.stderr)


def main():
    """Parse the command line, build an ``STTClient`` and run it.

    ``--list-devices`` prints the input devices and exits with status 0.
    When no stream id is given the machine's hostname is used. The chosen
    VAD threshold is stored in the module-level ``VAD_THRESHOLD``.
    """
    # The declaration must come before any use of the name in this function:
    # reading VAD_THRESHOLD (for the argparse default below) ahead of a later
    # `global` statement is a SyntaxError.
    global VAD_THRESHOLD

    parser = argparse.ArgumentParser(
        description="STT Streaming Client — streams mic audio to stt.mm.mk with VAD")
    parser.add_argument("--server", default=DEFAULT_SERVER,
                        help=f"WebSocket server URL (default: {DEFAULT_SERVER})")
    parser.add_argument("--language", "-l", default="pl",
                        help="Language code (default: pl)")
    parser.add_argument("--stream-id", "-s", default=None,
                        help="Stream identifier (default: hostname)")
    parser.add_argument("--device", "-d", type=int, default=None,
                        help="Audio input device index (use --list-devices to see)")
    parser.add_argument("--list-devices", action="store_true",
                        help="List audio devices and exit")
    parser.add_argument("--no-vad", action="store_true",
                        help="Disable VAD, send all audio")
    parser.add_argument("--vad-threshold", type=float, default=VAD_THRESHOLD,
                        help=f"VAD speech threshold 0-1 (default: {VAD_THRESHOLD})")
    parser.add_argument("--timestamps", action="store_true",
                        help="Enable timestamps")
    parser.add_argument("--diarize", action="store_true",
                        help="Enable speaker diarization")
    parser.add_argument("--itn", action="store_true",
                        help="Enable inverse text normalization")
    parser.add_argument("--emotion", action="store_true",
                        help="Enable emotion detection")
    parser.add_argument("--verbose", "-v", action="store_true",
                        help="Verbose output (show per-word confidence)")

    args = parser.parse_args()

    if args.list_devices:
        list_audio_devices()
        sys.exit(0)

    if args.stream_id is None:
        import socket
        args.stream_id = socket.gethostname()

    VAD_THRESHOLD = args.vad_threshold

    client = STTClient(
        server_url=args.server,
        language=args.language,
        stream_id=args.stream_id,
        device=args.device,
        use_vad=not args.no_vad,
        timestamps=args.timestamps,
        diarize=args.diarize,
        itn=args.itn,
        detect_emotion=args.emotion,
        verbose=args.verbose,
    )

    # Handle Ctrl+C gracefully: the handler only flips the flag, so no
    # KeyboardInterrupt is raised and the client winds down on its own.
    def sigint_handler(sig, frame):
        client.running = False
    signal.signal(signal.SIGINT, sigint_handler)

    client.run()


# Script entry point: start the client only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
