Custom Persistence class for python-telegram-bot
Asked Answered
N

2

6

I am developing a simple Telegram chat bot using python-telegram-bot library. My bot is currently using the ConversationHandler to keep track of the state of the conversation.

I want to make the conversation persistent by storing the conversation state in a MongoDB database.

I am using mongoengine library for python to communicate with my DB.

By reading the documentation for BasePersistence (https://python-telegram-bot.readthedocs.io/en/stable/telegram.ext.basepersistence.html) I understood that it is necessary to extend this class with a custom one, let's call it MongoPersistence, and to override the following methods:

  • get_conversations(name)
  • update_conversation(name, key, new_state)

The documentation doesn't specify the structure of the dict returned by get_conversations(name), so it is also difficult to understand how to implement update_conversation(name, key, new_state).

Suppose I have the class below (store_user_data, store_chat_data, store_bot_data are all set to False because I don't want to store this data):

from telegram.ext import BasePersistence


class MongoPersistence(BasePersistence):
    """Skeleton persistence backend that is meant to keep conversation state.

    The constructor switches off user, chat and bot data storage; the two
    conversation hooks are unimplemented placeholders.
    """

    def __init__(self):
        disabled_stores = {
            "store_user_data": False,
            "store_chat_data": False,
            "store_bot_data": False,
        }
        super().__init__(**disabled_stores)

    def get_conversations(self, name):
        """Placeholder: does nothing yet and returns ``None``."""
        return None

    def update_conversation(self, name, key, new_state):
        """Placeholder: does nothing yet and returns ``None``."""
        return None

How can I implement this class so that my conversation state will be fetched and saved from the DB?

Nondisjunction answered 21/3, 2020 at 1:19 Comment(0)
M
11

Conversation Persistence

I guess the simplest way to implement it is to look at PicklePersistence(). The only example I've seen of the dictionary is conversations = { name : { (user_id,user_id): state} } where name is the one given to ConversationHandler(), the tuple-as-a-key (user_id,user_id) identifies the user your bot is talking to and state is the state of the conversation. OK, maybe one of the two isn't user_id but chat_id (by default ConversationHandler keys a conversation by chat and user, and in a private chat both IDs are the same number, which would explain the identical values), but I can't say for sure; I need more guinea pigs.

To handle the tuple-as-a-key, python-telegram-bot includes some tools to help you handle that: encode_conversations_to_json and decode_conversations_from_json.

Here, on_flush is a flag that controls when the data is saved: when it is set to False, the code saves on every call to update_conversation(); when it is set to True, it saves only when the program exits (in flush()).

One last detail: for now the following code only saves to and retrieves from the database; there is no replacing or deleting of documents.

from telegram.ext import BasePersistence
from config import mongo_URI
from copy import deepcopy
from telegram.utils.helpers import decode_conversations_from_json, encode_conversations_to_json
import mongoengine
import json
from bson import json_util

class Conversations(mongoengine.Document):
    """Document whose ``obj`` field holds a conversations mapping.

    Stored in the ``Conversations`` collection; queries are ordered by
    descending ``id`` by default.
    """

    meta = dict(collection='Conversations', ordering=['-id'])
    obj = mongoengine.DictField()

class MongoPersistence(BasePersistence):
    """Conversation-only persistence that stores state in MongoDB via mongoengine.

    The in-memory cache ``self.conversations`` maps a handler name to a dict
    of ``{key_tuple: state}``.  Every save inserts a new ``Conversations``
    document containing the whole mapping (nothing is replaced or deleted);
    loading reads the first document under that model's default ordering.

    Attributes:
        conversation_collection: Name of the collection used for conversations.
        conversations: Cached conversations mapping, ``None`` until loaded.
        on_flush: When ``False`` every state change is written immediately;
            when ``True`` data is only written by :meth:`flush`.
    """

    def __init__(self):
        super(MongoPersistence, self).__init__(store_user_data=False,
                                               store_chat_data=False,
                                               store_bot_data=False)
        dbname = "persistencedb"
        mongoengine.connect(host=mongo_URI, db=dbname)
        self.conversation_collection = "Conversations"
        self.conversations = None
        self.on_flush = False

    def _load_conversations(self):
        """Fill the cache from the database unless it is already loaded."""
        # ``is None`` (not truthiness): an empty mapping is a valid, loaded
        # cache and must not trigger another database query.
        if self.conversations is not None:
            return
        latest = Conversations.objects().first()
        stored = {} if latest is None else latest['obj']
        self.conversations = decode_conversations_from_json(json_util.dumps(stored))

    def _save_conversations(self):
        """Insert the cached mapping as a new ``Conversations`` document."""
        payload = json_util.loads(encode_conversations_to_json(self.conversations))
        Conversations(obj=payload).save()

    def get_conversations(self, name):
        """Return a shallow copy of the stored conversations for handler ``name``.

        Args:
            name: The name given to the ``ConversationHandler``.

        Returns:
            A dict of ``{key_tuple: state}``; empty if nothing is stored.
        """
        self._load_conversations()
        return self.conversations.get(name, {}).copy()

    def update_conversation(self, name, key, new_state):
        """Record ``new_state`` for ``key`` and save it unless ``on_flush`` is set.

        Args:
            name: The name given to the ``ConversationHandler``.
            key: Tuple identifying the conversation.
            new_state: The new state of the conversation.
        """
        # Load first so that a call arriving before get_conversations() neither
        # crashes on a ``None`` cache nor saves a snapshot missing stored data.
        self._load_conversations()
        if self.conversations.setdefault(name, {}).get(key) == new_state:
            return
        self.conversations[name][key] = new_state
        if not self.on_flush:
            self._save_conversations()

    def flush(self):
        """Save the cached conversations (if any were loaded) and disconnect."""
        # Nothing was ever loaded or updated: there is nothing to save, but
        # the connection must still be closed.
        if self.conversations is not None:
            self._save_conversations()
        mongoengine.disconnect()

BEWARE! Sometimes a conversation requires previously set user_data, and this code doesn't provide it, because it does not persist user_data.

All Persistence

Here is a more complete version of the code (it still lacks replacing the document in the database).

from telegram.ext import BasePersistence
from collections import defaultdict
from config import mongo_URI
from copy import deepcopy
from telegram.utils.helpers import decode_user_chat_data_from_json, decode_conversations_from_json, encode_conversations_to_json
import mongoengine
import json
from bson import json_util

class Conversations(mongoengine.Document):
    """Document whose ``obj`` field holds a conversations mapping.

    Stored in the ``Conversations`` collection; queries are ordered by
    descending ``id`` by default.
    """

    meta = dict(collection='Conversations', ordering=['-id'])
    obj = mongoengine.DictField()

class UserData(mongoengine.Document):
    """Document whose ``obj`` field holds a user-data mapping.

    Stored in the ``UserData`` collection; queries are ordered by
    descending ``id`` by default.
    """

    meta = dict(collection='UserData', ordering=['-id'])
    obj = mongoengine.DictField()

class ChatData(mongoengine.Document):
    """Document whose ``obj`` field holds a chat-data mapping.

    Stored in the ``ChatData`` collection; queries are ordered by
    descending ``id`` by default.
    """

    meta = dict(collection='ChatData', ordering=['-id'])
    obj = mongoengine.DictField()

class BotData(mongoengine.Document):
    """Document whose ``obj`` field holds a bot-data mapping.

    Stored in the ``BotData`` collection; queries are ordered by
    descending ``id`` by default.
    """

    meta = dict(collection='BotData', ordering=['-id'])
    obj = mongoengine.DictField()

class DBHelper():
    """Add and get documents from a mongo database using mongoengine.

    Collections are selected by name: ``"Conversations"``, ``"UserData"`` and
    ``"ChatData"`` map to the document classes of the same name; every other
    name falls back to ``BotData``.
    """

    def __init__(self, dbname="persistencedb"):
        """Connect mongoengine to database ``dbname`` at ``mongo_URI``."""
        mongoengine.connect(host=mongo_URI, db=dbname)

    @staticmethod
    def _model_for(collection):
        """Return the document class used for the given collection name."""
        models = {
            "Conversations": Conversations,
            "UserData": UserData,
            "ChatData": ChatData,
            # Legacy spelling: add_item used to test for this string instead
            # of "ChatData", which sent chat data to BotData.  Kept as an
            # alias so callers that used it still work.
            "chat_data_collection": ChatData,
        }
        return models.get(collection, BotData)

    def add_item(self, data, collection):
        """Insert ``data`` as a new document in the named collection.

        Args:
            data: Dict stored in the document's ``obj`` field.
            collection: Collection name (see the class docstring).
        """
        self._model_for(collection)(obj=data).save()

    def get_item(self, collection):
        """Return the ``obj`` payload of the first document in the collection.

        "First" follows the document class's default ordering.

        Args:
            collection: Collection name (see the class docstring).

        Returns:
            The stored mapping, or ``{}`` when the collection has no documents.
        """
        latest = self._model_for(collection).objects().first()
        return {} if latest is None else latest['obj']

    def close(self):
        """Disconnect mongoengine from the database."""
        mongoengine.disconnect()

class DBPersistence(BasePersistence):
    """Uses DBHelper to make the bot persistent on a database.

    Heavily inspired by PicklePersistence from python-telegram-bot.  Each kind
    of data (conversations, user, chat, bot) is cached in memory and written
    through ``DBHelper.add_item`` as a complete snapshot.

    Attributes:
        on_flush: When ``False`` every detected change is written immediately;
            when ``True`` data is only written by :meth:`flush`.
    """

    def __init__(self):
        super(DBPersistence, self).__init__(store_user_data=True,
                                            store_chat_data=True,
                                            store_bot_data=True)
        # Not read anywhere in this class; DBHelper() uses its own default name.
        self.persistdb = "persistancedb"
        self.conversation_collection = "Conversations"
        self.user_data_collection = "UserData"
        self.chat_data_collection = "ChatData"
        self.bot_data_collection = "BotData"
        self.db = DBHelper()
        self.user_data = None
        self.chat_data = None
        self.bot_data = None
        self.conversations = None
        self.on_flush = False

    # ------------------------------------------------------------------ helpers

    def _load_keyed(self, collection):
        """Load a user/chat data mapping from ``collection``."""
        stored = json_util.dumps(self.db.get_item(collection))
        if stored != '{}':
            return decode_user_chat_data_from_json(stored)
        return defaultdict(dict)

    def _load_conversations(self):
        """Fill the conversations cache from the database unless already loaded."""
        if self.conversations is None:
            stored = json_util.dumps(self.db.get_item(self.conversation_collection))
            self.conversations = decode_conversations_from_json(stored)

    def _save_json(self, data, collection):
        """Round-trip ``data`` through JSON and insert it into ``collection``."""
        self.db.add_item(json_util.loads(json.dumps(data)), collection)

    def _save_conversations(self):
        """Insert the cached conversations as a new snapshot."""
        payload = json_util.loads(encode_conversations_to_json(self.conversations))
        self.db.add_item(payload, self.conversation_collection)

    # ------------------------------------------------------------ conversations

    def get_conversations(self, name):
        """Return a shallow copy of the stored conversations for handler ``name``."""
        self._load_conversations()
        return self.conversations.get(name, {}).copy()

    def update_conversation(self, name, key, new_state):
        """Record ``new_state`` for ``key`` and save it unless ``on_flush`` is set."""
        # Load first so a call arriving before get_conversations() does not
        # crash on a ``None`` cache.
        self._load_conversations()
        if self.conversations.setdefault(name, {}).get(key) == new_state:
            return
        self.conversations[name][key] = new_state
        if not self.on_flush:
            self._save_conversations()

    # ---------------------------------------------------------------- user data

    def get_user_data(self):
        """Return a deep copy of the user data, loading it on first use."""
        # ``is None`` (not truthiness): an empty mapping is already loaded.
        if self.user_data is None:
            self.user_data = self._load_keyed(self.user_data_collection)
        return deepcopy(self.user_data)

    def update_user_data(self, user_id, data):
        """Store ``data`` for ``user_id`` and save it unless ``on_flush`` is set."""
        if self.user_data is None:
            self.user_data = self._load_keyed(self.user_data_collection)
        if self.user_data.get(user_id) == data:
            return
        # Keep a private copy: if the caller's dict were stored by reference,
        # its later in-place changes would make the equality check above
        # always succeed and nothing would ever be saved.
        self.user_data[user_id] = deepcopy(data)
        if not self.on_flush:
            self._save_json(self.user_data, self.user_data_collection)

    # ---------------------------------------------------------------- chat data

    def get_chat_data(self):
        """Return a deep copy of the chat data, loading it on first use."""
        if self.chat_data is None:
            self.chat_data = self._load_keyed(self.chat_data_collection)
        return deepcopy(self.chat_data)

    def update_chat_data(self, chat_id, data):
        """Store ``data`` for ``chat_id`` and save it unless ``on_flush`` is set."""
        if self.chat_data is None:
            self.chat_data = self._load_keyed(self.chat_data_collection)
        if self.chat_data.get(chat_id) == data:
            return
        # Private copy for the same reason as in update_user_data.
        self.chat_data[chat_id] = deepcopy(data)
        if not self.on_flush:
            self._save_json(self.chat_data, self.chat_data_collection)

    # ----------------------------------------------------------------- bot data

    def get_bot_data(self):
        """Return a deep copy of the bot data, loading it on first use."""
        if self.bot_data is None:
            stored = json_util.dumps(self.db.get_item(self.bot_data_collection))
            self.bot_data = json.loads(stored)
        return deepcopy(self.bot_data)

    def update_bot_data(self, data):
        """Store ``data`` as the bot data and save it unless ``on_flush`` is set."""
        if self.bot_data == data:
            return
        # Deep copy so nested values are not shared with the caller's dict.
        self.bot_data = deepcopy(data)
        if not self.on_flush:
            # Save the JSON-normalised form, like the other collections.
            self._save_json(self.bot_data, self.bot_data_collection)

    # -------------------------------------------------------------------- flush

    def flush(self):
        """Save every non-empty cache and close the database connection."""
        if self.conversations:
            self._save_conversations()
        if self.user_data:
            self._save_json(self.user_data, self.user_data_collection)
        if self.chat_data:
            self._save_json(self.chat_data, self.chat_data_collection)
        if self.bot_data:
            self._save_json(self.bot_data, self.bot_data_collection)
        self.db.close()

Two details:

  1. Chat_data persistence isn't being saved to the database yet. It needs more testing; maybe that part of the code has a bug.
  2. For now the only part of the code where on_flush = False works is in Conversations. In all other updates the call seems to happen after the assignment, so the check variable[key] == data is always True and the function returns before saving to the database. That's why there is a comment saying # comment next line if you want to save to db every time this function is called; doing so works, but it causes a lot of saves. If you set on_flush = True and the code stops early (the process is killed, for example), nothing will be saved to the database.
Marisolmarissa answered 3/6, 2020 at 13:9 Comment(0)
E
0

Here is what we ended up with. It is based on MongoEngine and it works. We would be happy to hear feedback. It keeps the conversation state persistent across bot restarts.

from typing import DefaultDict, Optional, Tuple, cast, Any
from collections import defaultdict
import json

from mongoengine import connect, Document, StringField, IntField, DictField, ListField, DynamicField
from telegram.ext import BasePersistence
from telegram.ext._utils.types import BD, CD, UD, CDCData, ConversationDict

class BaseDocument(Document):
    """Abstract base document that adds an optional collection-name prefix."""

    meta = {'abstract': True, 'allow_inheritance': True}

    @classmethod
    def set_namespace(cls, namespace: str):
        """Prefix this document's collection name with ``namespace``.

        The collection becomes ``"<namespace>_<original name>"``.  The call is
        idempotent: the original name is remembered, so calling this again
        (with the same or another namespace) replaces the prefix instead of
        stacking a second one.  An empty namespace leaves the name untouched.

        Args:
            namespace: Prefix to apply; falsy values are ignored.
        """
        if not namespace:
            return
        # Remember the un-prefixed name on first use so repeated calls always
        # start from it rather than from an already prefixed name.
        # NOTE(review): presumably this must run before the collection is
        # first accessed for the new name to take effect - confirm.
        base_name = cls._meta.setdefault('base_collection', cls._meta['collection'])
        cls._meta['collection'] = f"{namespace}_{base_name}"

class BotData(BaseDocument):
    """Bot-wide data dict, stored in the ``bot_data`` collection."""

    data = DictField()
    meta = dict(collection='bot_data')

class ChatData(BaseDocument):
    """Data dict for one chat (``chat_id``), stored in ``chat_data``."""

    chat_id = IntField(required=True)
    data = DictField()
    meta = dict(collection='chat_data')

class UserData(BaseDocument):
    """Data dict for one user (``user_id``), stored in ``user_data``."""

    user_id = IntField(required=True)
    data = DictField()
    meta = dict(collection='user_data')

class CallbackData(BaseDocument):
    """Callback data kept as a list, stored in the ``callback_data`` collection."""

    data = ListField()
    meta = dict(collection='callback_data')

class ConversationData(BaseDocument):
    """State of one conversation, identified by handler ``name`` and ``key``.

    Stored in the ``conversation_data`` collection.
    """

    name = StringField(required=True)
    key = StringField(required=True)
    state = DynamicField()
    meta = dict(collection='conversation_data')

class MongoPersistence(BasePersistence[UD, CD, BD]):
    """Persistence backend that stores bot, chat, user, callback and
    conversation data in MongoDB through the MongoEngine document classes.

    Args:
        mongo_url: Host/URL passed to ``mongoengine.connect``.
        database_name: Name of the database to connect to.
        namespace: Optional prefix applied to every collection name.
        store_data: Forwarded unchanged to ``BasePersistence``.
        update_interval: Forwarded unchanged to ``BasePersistence``.
    """

    def __init__(
        self,
        mongo_url: str,
        database_name: str,
        namespace: str = "",
        store_data=None,
        update_interval: int = 60,
    ):
        super().__init__(store_data=store_data, update_interval=update_interval)
        # Drop surrounding whitespace before the value is used as a prefix.
        self._namespace = namespace.strip()
        connect(db=database_name, host=mongo_url)

        # Every document class gets the same namespace.
        for document_cls in (BotData, ChatData, UserData, CallbackData, ConversationData):
            document_cls.set_namespace(self._namespace)

    async def get_bot_data(self) -> BD:
        """Return the stored bot data as a dict, or ``{}`` if none exists."""
        try:
            record = BotData.objects.get()
        except BotData.DoesNotExist:
            return {}
        return dict(record.data)

    async def update_bot_data(self, data: BD) -> None:
        """Write ``data`` as the bot data, creating the document if needed."""
        every_bot_doc = BotData.objects()
        every_bot_doc.update_one(set__data=data, upsert=True)

    async def refresh_bot_data(self, bot_data: BD) -> None:
        """Replace the contents of ``bot_data`` in place with the stored data.

        Does nothing when ``bot_data`` is not a dict.
        """
        if not isinstance(bot_data, dict):
            return
        stored = await self.get_bot_data()
        bot_data.clear()
        bot_data.update(stored)

    async def get_chat_data(self) -> DefaultDict[int, CD]:
        """Return all stored chat data keyed by ``chat_id``."""
        by_chat = defaultdict(dict)
        by_chat.update((doc.chat_id, dict(doc.data)) for doc in ChatData.objects())
        return by_chat

    async def update_chat_data(self, chat_id: int, data: CD) -> None:
        """Write ``data`` for ``chat_id``, creating the document if needed."""
        matching = ChatData.objects(chat_id=chat_id)
        matching.update_one(set__data=data, upsert=True)

    async def refresh_chat_data(self, chat_id: int, chat_data: CD) -> None:
        """Replace ``chat_data`` in place with the stored data for ``chat_id``.

        Leaves ``chat_data`` untouched when nothing is stored for the chat.
        """
        try:
            record = ChatData.objects.get(chat_id=chat_id)
        except ChatData.DoesNotExist:
            return
        stored = dict(record.data)
        chat_data.clear()
        chat_data.update(stored)

    async def get_user_data(self) -> DefaultDict[int, UD]:
        """Return all stored user data keyed by ``user_id``."""
        by_user = defaultdict(dict)
        by_user.update((doc.user_id, dict(doc.data)) for doc in UserData.objects())
        return by_user

    async def update_user_data(self, user_id: int, data: UD) -> None:
        """Write ``data`` for ``user_id``, creating the document if needed."""
        matching = UserData.objects(user_id=user_id)
        matching.update_one(set__data=data, upsert=True)

    async def refresh_user_data(self, user_id: int, user_data: UD) -> None:
        """Replace ``user_data`` in place with the stored data for ``user_id``.

        Leaves ``user_data`` untouched when nothing is stored for the user.
        """
        try:
            record = UserData.objects.get(user_id=user_id)
        except UserData.DoesNotExist:
            return
        stored = dict(record.data)
        user_data.clear()
        user_data.update(stored)

    async def drop_chat_data(self, chat_id: int) -> None:
        """Delete the stored data for ``chat_id``."""
        ChatData.objects(chat_id=chat_id).delete()

    async def drop_user_data(self, user_id: int) -> None:
        """Delete the stored data for ``user_id``."""
        UserData.objects(user_id=user_id).delete()

    async def get_callback_data(self) -> Optional[CDCData]:
        """Return the stored callback data, or ``None`` if none exists.

        The second element of each stored triple is converted to ``float``.
        """
        try:
            stored = CallbackData.objects.get().data
        except CallbackData.DoesNotExist:
            return None
        triples = [(one, float(two), three) for one, two, three in stored[0]]
        return cast(CDCData, (triples, stored[1]))

    async def update_callback_data(self, data: CDCData) -> None:
        """Write ``data`` as the callback data, creating the document if needed."""
        every_callback_doc = CallbackData.objects()
        every_callback_doc.update_one(set__data=data, upsert=True)

    async def get_conversations(self, name: str) -> ConversationDict:
        """Return ``{key_tuple: state}`` for the handler called ``name``.

        Keys are stored as JSON strings and turned back into tuples here.
        """
        return {
            tuple(json.loads(doc.key)): doc.state
            for doc in ConversationData.objects(name=name)
        }

    async def update_conversation(self, name: str, key: Tuple[int, ...], new_state: Optional[object]) -> None:
        """Write ``new_state`` for the conversation ``key`` of handler ``name``.

        The key is serialised to a JSON string; the document is created if it
        does not exist yet.
        """
        encoded_key = json.dumps(key, sort_keys=True)
        matching = ConversationData.objects(name=name, key=encoded_key)
        matching.update_one(set__state=new_state, upsert=True)

    async def flush(self) -> None:
        """No-op: this method performs no work."""
        return None

    def close(self) -> None:
        """No-op: this method performs no work."""
        return None
Ebbie answered 18/10 at 16:16 Comment(0)

© 2022 - 2024 — McMap. All rights reserved.