feat: check file MIME type before sending it to Whisper

This commit is contained in:
MrRaph_
2024-12-13 15:29:42 +01:00
parent 693472186d
commit 5ba28d859c
3 changed files with 7 additions and 3 deletions

View File

@@ -1,6 +1,6 @@
maubot: 0.1.0 maubot: 0.1.0
id: nigzu.com.maubot-stt id: nigzu.com.maubot-stt
version: 0.3.0 version: 0.3.1
license: MIT license: MIT
modules: modules:
- openai-whisper - openai-whisper

Binary file not shown.

View File

@@ -12,6 +12,9 @@ from mautrix.errors import MatrixRequestError
from mautrix.types import EventType, MessageType, RelationType, TextMessageEventContent, Format,RelatesTo,InReplyTo from mautrix.types import EventType, MessageType, RelationType, TextMessageEventContent, Format,RelatesTo,InReplyTo
from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper
ALLOWED_EXTENSIONS = ['flac', 'm4a', 'mp3', 'mp4', 'mpeg', 'mpga', 'oga', 'ogg', 'wav', 'webm']
ALLOWED_MIME_TYPES = ['audio/flac','audio/mp4','video/mpeg','audio/ogg','audio/wav','video/webm']
class Config(BaseProxyConfig): class Config(BaseProxyConfig):
def do_update(self, helper: ConfigUpdateHelper) -> None: def do_update(self, helper: ConfigUpdateHelper) -> None:
helper.copy("whisper_endpoint") helper.copy("whisper_endpoint")
@@ -45,6 +48,9 @@ class WhisperPlugin(Plugin):
if self.allowed_rooms and event.room_id not in self.allowed_rooms: if self.allowed_rooms and event.room_id not in self.allowed_rooms:
return False return False
if event.content.info.mimetype not in ALLOWED_MIME_TYPES:
return False
return event.content.msgtype == MessageType.AUDIO or event.content.msgtype == MessageType.FILE return event.content.msgtype == MessageType.AUDIO or event.content.msgtype == MessageType.FILE
@@ -56,8 +62,6 @@ class WhisperPlugin(Plugin):
try: try:
await event.mark_read() await event.mark_read()
await self.client.set_typing(event.room_id, timeout=99999) await self.client.set_typing(event.room_id, timeout=99999)
self.log.error(event)
audio_bytes = await self.client.download_media(url=event.content.url) audio_bytes = await self.client.download_media(url=event.content.url)
transcription = await self.transcribe_audio(audio_bytes) transcription = await self.transcribe_audio(audio_bytes)