Unit testing a repository with Paging 3 using a RemoteMediator and a PagingSource

I am trying to unit test a repository class in Android that uses Paging 3 with a RemoteMediator and a PagingSource.

But when I run the test, the returned result is empty, although it should contain the test item.

Here is my repository:

class PostsRepository @Inject constructor(
    private val postsApi: AutomatticPostsApi,
    private val postsDao: PostDao
) : IPostsRepository {

    @ExperimentalPagingApi
    override fun loadPosts(): Flow<PagingData<Post>> {
        println("loadPosts")
        return Pager(
            config = PagingConfig(20),
            initialKey = 1,
            remoteMediator = PostsPageRemoteMediator(
                postsApi,
                postsDao
            ),
            pagingSourceFactory = { postsDao.getPostsPagingSource() }
        ).flow.map { pagingData ->
            pagingData.map { it.toPost() }
        }
    }

}

Here is my unit test:

@ExperimentalCoroutinesApi
@ExperimentalPagingApi
class PostsRepositoryTest {
    @get:Rule
    val instantTaskExecutorRule = InstantTaskExecutorRule()

    private val coroutineDispatcher = TestCoroutineDispatcher()
    private lateinit var postDao: FakePostDao
    private lateinit var postsApi: CommonAutomatticPostsApi
    private val remotePosts = listOf(createDummyPostResponse())
    private val domainPosts = remotePosts.map { it.toPost() }

    //GIVEN: subject under test
    private lateinit var postsRepository: PostsRepository

    @Before
    fun createRepository() =  coroutineDispatcher.runBlockingTest {
        postsApi = CommonAutomatticPostsApi(remotePosts.toMutableList())
        postDao = FakePostDao()
        postsRepository = PostsRepository(postsApi, postDao)
    }

    @Test
    fun loadPosts_returnsCorrectPosts() = runBlockingTest {
        //WHEN: posts are retrieved from paging source

        launch {

            postsRepository.loadPosts().collect { pagingData ->

                val posts = mutableListOf<Post>()
                pagingData.map {

                    posts.add(it)
                    println(it)
                }

                //THEN: retrieved posts should be the remotePosts
                assertThat(posts, IsEqual(domainPosts))
            }

        }

    }
}

Here are the fake API, the fake PagingSource and the fake DAO:

class CommonAutomatticPostsApi(val posts: MutableList<PostResponse> = mutableListOf()) : AutomatticPostsApi {
    companion object {
        const val SUBSCRIBER_COUNT = 2L
        const val AUTHOR_NAME = "RR"
    }

    override suspend fun loadPosts(page: Int, itemCount: Int): PostsResponse {
        println("Loaded")
        return PostsResponse(posts.size.toLong(), posts)
    }
}

class FakePostsPagingSource() : PagingSource<Int, PostEntity>() {
    var triggerError = false
    var posts: List<PostEntity> = emptyList()
        set(value) {
            println("set")
            field = value
            invalidate()
        }

    override suspend fun load(params: LoadParams<Int>): LoadResult<Int, PostEntity> {
        println("load")
        if (triggerError) {
            return LoadResult.Error(Exception("A test error triggered"))
        }
        println("not error")

        return LoadResult.Page(
            data = posts,
            prevKey = null,
            nextKey = null
        )
    }

    override fun getRefreshKey(state: PagingState<Int, PostEntity>): Int? {
        println("refresh")

        return state.anchorPosition ?: 1
    }
}

class FakePostDao(val posts: MutableList<PostEntity> = mutableListOf()) : PostDao {
    val pagingSource = FakePostsPagingSource()

    override suspend fun insertPosts(posts: List<PostEntity>) {
        this.posts.addAll(posts)
        println("insertPosts")
        updatePagingSource()
    }

    override suspend fun updatePost(post: PostEntity) {
        onValidPost(post.id) {
            posts[it] = post
            updatePagingSource()
        }
    }

    private fun onValidPost(postId: Long, block: (index: Int) -> Unit): Boolean {
        println("onValidPost")

        val index = posts.indexOfFirst { it.id == postId }
        if (index != -1) {
            block(index)
            return true
        }

        return false
    }

    override suspend fun updatePost(postId: Long, subscriberCount: Long) {
        onValidPost(postId) {
            posts[it] = posts[it].copy(subscriberCount = subscriberCount)
            updatePagingSource()
        }
    }

    override suspend fun getPostById(postId: Long): PostEntity? {
        val index = posts.indexOfFirst { it.id == postId }
        return if (index != -1) {
            posts[index]
        } else {
            null
        }
    }

    override suspend fun getPosts(): List<PostEntity> {
        println("getPosts")

        return posts
    }

    override fun getPostsPagingSource(): PagingSource<Int, PostEntity> {
        println("getPostsPagingSource")

        return pagingSource
    }

    override suspend fun clearAll() {
        posts.clear()
        updatePagingSource()
    }

    private fun updatePagingSource() {
        println("updatePagingSource")

        pagingSource.posts = posts
    }

    @Transaction
    override suspend fun refreshPosts(newPosts: List<PostEntity>) {
        println("refreshPosts")
        clearAll()
        insertPosts(newPosts)
    }
}
Teofilateosinte answered 6/3, 2021 at 8:46

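UPDATE 28/8:

There is a complete topic here about best practices for testing pagination.

OLD ANSWER:

If you look at the Paging library source code, you will see that fetching data from the DAO or the RemoteMediator does not start until something collects from the PagingData. Calling pagingData.map { ... }, as the test in the question does, only returns a new, transformed PagingData; it does not iterate over the items, so nothing is ever loaded. I have found a utility function that starts collecting from a PagingData: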

suspend fun <T : Any> PagingData<T>.collectData(): List<T> {
    val dcb = object : DifferCallback {
        override fun onChanged(position: Int, count: Int) {}
        override fun onInserted(position: Int, count: Int) {}
        override fun onRemoved(position: Int, count: Int) {}
    }
    val items = mutableListOf<T>()
    val dif = object : PagingDataDiffer<T>(dcb, TestCoroutineDispatcher()) {
        override suspend fun presentNewList(
            previousList: NullPaddedList<T>,
            newList: NullPaddedList<T>,
            newCombinedLoadStates: CombinedLoadStates,
            lastAccessedIndex: Int
        ): Int? {
            for (idx in 0 until newList.size)
                items.add(newList.getFromStorage(idx))
            return null
        }
    }
    dif.collectFrom(this)
    return items
}

You can use it like this:

postsRepository.loadPosts().collect { pagingData ->
    val posts = pagingData.collectData()
    //THEN: retrieved posts should be the remotePosts
    assertThat(posts, IsEqual(domainPosts))
}

I have tried writing tests like this, but I don't think it is the best way to test pagination: instead of testing your own code, you end up testing the Paging library. A better alternative is to test your DAO, your PagingSource and your mappers separately.
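
For example, a PagingSource can be tested on its own, without a Pager, by calling load() directly with a PagingSource.LoadParams.Refresh (or Append) and checking the returned LoadResult. This is a minimal sketch, assuming androidx.paging:paging-common 3.x and kotlinx-coroutines-core are on the test classpath; ListPagingSource is a hypothetical in-memory source with 1-based page keys, standing in for a DAO-backed source such as the one returned by getPostsPagingSource().

```kotlin
import androidx.paging.PagingSource
import androidx.paging.PagingState
import kotlinx.coroutines.runBlocking

// Hypothetical in-memory source: page N (1-based) holds items [(N-1)*pageSize, N*pageSize).
// It ignores params.loadSize and always uses pageSize, so each key maps to a fixed offset.
class ListPagingSource(
    private val items: List<String>,
    private val pageSize: Int
) : PagingSource<Int, String>() {

    override suspend fun load(params: LoadParams<Int>): LoadResult<Int, String> {
        val page = params.key ?: 1
        val from = ((page - 1) * pageSize).coerceIn(0, items.size)
        val to = (from + pageSize).coerceAtMost(items.size)
        return LoadResult.Page(
            data = items.subList(from, to),
            prevKey = if (page <= 1) null else page - 1,
            nextKey = if (to < items.size) page + 1 else null
        )
    }

    // Returning null makes a refresh start again from page 1.
    override fun getRefreshKey(state: PagingState<Int, String>): Int? = null
}

fun main() = runBlocking {
    val source = ListPagingSource(listOf("a", "b", "c"), pageSize = 2)

    // Initial load: no key, so page 1 is returned.
    val first = source.load(
        PagingSource.LoadParams.Refresh(key = null, loadSize = 2, placeholdersEnabled = false)
    )
    check(first == PagingSource.LoadResult.Page(listOf("a", "b"), prevKey = null, nextKey = 2))

    // Append using the nextKey from the previous page; this is the last page.
    val second = source.load(
        PagingSource.LoadParams.Append(key = 2, loadSize = 2, placeholdersEnabled = false)
    )
    check(second == PagingSource.LoadResult.Page(listOf("c"), prevKey = 1, nextKey = null))
    println("ok")
}
```

LoadResult.Page compares by value, so a whole page (items and both keys) can be checked with a single equality assertion.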

Coffeecolored answered 18/3, 2021 at 7:58

PagingData test helper (Kotlin), updated for Paging 3.0.0-rc01, where presentNewList takes an additional onListPresentable callback. Thanks to Farid.

    private suspend fun <T : Any> PagingData<T>.collectDataForTest(): List<T> {
        val dcb = object : DifferCallback {
            override fun onChanged(position: Int, count: Int) {}
            override fun onInserted(position: Int, count: Int) {}
            override fun onRemoved(position: Int, count: Int) {}
        }
        val items = mutableListOf<T>()
        val dif = object : PagingDataDiffer<T>(dcb, TestCoroutineDispatcher()) {
            override suspend fun presentNewList(
                previousList: NullPaddedList<T>,
                newList: NullPaddedList<T>,
                newCombinedLoadStates: CombinedLoadStates,
                lastAccessedIndex: Int,
                onListPresentable: () -> Unit
            ): Int? {
                for (idx in 0 until newList.size)
                    items.add(newList.getFromStorage(idx))
                onListPresentable()
                return null
            }
        }
        dif.collectFrom(this)
        return items
    }

Usage:

    // searchHistoryList: Flow<PagingData<Your Data type>>
    val tmp = useCase.searchHistoryList.take(1).toList().first()
    // result: List<Your Data type>
    val result = tmp.collectDataForTest()
    assertEquals(expect, result)
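
In newer Paging releases (the androidx.paging:paging-testing artifact was added in 3.2.0), a hand-written differ like collectDataForTest() should not be necessary: that artifact provides an asSnapshot() extension that collects a Flow<PagingData<T>>, runs the initial load and returns the loaded items as a List. The sketch below assumes paging-testing and kotlinx-coroutines-test are on the test classpath; PostEntity, Post and buildPostsFlow are hypothetical, simplified stand-ins for the question's types, and asPagingSourceFactory() (from the same artifact) replaces the DAO's paging source with an in-memory list.

```kotlin
import androidx.paging.Pager
import androidx.paging.PagingConfig
import androidx.paging.PagingData
import androidx.paging.map
import androidx.paging.testing.asPagingSourceFactory
import androidx.paging.testing.asSnapshot
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.test.runTest

// Hypothetical, simplified stand-ins for the question's entity and domain types.
data class PostEntity(val id: Long, val title: String)
data class Post(val id: Long, val title: String)

fun PostEntity.toPost() = Post(id, title)

// Same shape as PostsRepository.loadPosts(), minus the RemoteMediator:
// a Pager over an in-memory list, mapped from entities to domain objects.
fun buildPostsFlow(entities: List<PostEntity>): Flow<PagingData<Post>> =
    Pager(
        config = PagingConfig(pageSize = 20),
        pagingSourceFactory = entities.asPagingSourceFactory()
    ).flow.map { pagingData -> pagingData.map { it.toPost() } }

fun main() = runTest {
    val entities = listOf(PostEntity(1, "first"), PostEntity(2, "second"))

    // asSnapshot() acts as the collector: it drives the initial load and
    // returns the items that were presented as a plain List.
    val posts: List<Post> = buildPostsFlow(entities).asSnapshot()

    check(posts == entities.map { it.toPost() })
    println(posts)
}
```

asSnapshot() also takes an optional block of loader operations, such as scrollTo() or appendScrollWhile(), for tests that need more than the first pages.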
Arium answered 5/5, 2021 at 5:12
Does this implementation work for LiveData? I didn't understand how the extension function is used. Did you use it inside your unit test? - Dato
