I am trying to unit test a repository class in android which is using paging with a remote mediator and paging source.
But when I run the test the returned result is empty, although actual should contain the test item.
Like below :
Here is my repository
class PostsRepository @Inject constructor(
private val postsApi: AutomatticPostsApi,
private val postsDao: PostDao
) : IPostsRepository {
@ExperimentalPagingApi
override fun loadPosts(): Flow<PagingData<Post>> {
println("loadPosts")
return Pager(
config = PagingConfig(20),
initialKey = 1,
remoteMediator = PostsPageRemoteMediator(
postsApi,
postsDao
),
pagingSourceFactory = { postsDao.getPostsPagingSource() }
).flow.map { pagingData ->
pagingData.map { it.toPost() }
}
}
}
Here is my UT
@ExperimentalCoroutinesApi
@ExperimentalPagingApi
class PostsRepositoryTest {
@get:Rule
val instantTaskExecutorRule = InstantTaskExecutorRule()
private val coroutineDispatcher = TestCoroutineDispatcher()
private lateinit var postDao: FakePostDao
private lateinit var postsApi: CommonAutomatticPostsApi
private val remotePosts = listOf(createDummyPostResponse())
private val domainPosts = remotePosts.map { it.toPost() }
//GIVEN: subject under test
private lateinit var postsRepository: PostsRepository
@Before
fun createRepository() = coroutineDispatcher.runBlockingTest {
postsApi = CommonAutomatticPostsApi(remotePosts.toMutableList())
postDao = FakePostDao()
postsRepository = PostsRepository(postsApi, postDao)
}
@Test
fun loadPosts_returnsCorrectPosts() = runBlockingTest {
//WHEN: posts are retrieved from paging source
launch {
postsRepository.loadPosts().collect { pagingData ->
val posts = mutableListOf<Post>()
pagingData.map {
posts.add(it)
println(it)
}
//THEN: retrieved posts should be the remotePosts
assertThat(posts, IsEqual(domainPosts))
}
}
}
}
Here is the FakeApi, FakePagingSource and FakeDao
class CommonAutomatticPostsApi(val posts: MutableList<PostResponse> = mutableListOf()) : AutomatticPostsApi {
companion object {
const val SUBSCRIBER_COUNT = 2L
const val AUTHOR_NAME = "RR"
}
override suspend fun loadPosts(page: Int, itemCount: Int): PostsResponse {
println("Loaded")
return PostsResponse(posts.size.toLong(), posts)
}
}
class FakePostsPagingSource() : PagingSource<Int, PostEntity>() {
var triggerError = false
var posts: List<PostEntity> = emptyList()
set(value) {
println("set")
field = value
invalidate()
}
override suspend fun load(params: LoadParams<Int>): LoadResult<Int, PostEntity> {
println("load")
if (triggerError) {
return LoadResult.Error(Exception("A test error triggered"))
}
println("not error")
return LoadResult.Page(
data = posts,
prevKey = null,
nextKey = null
)
}
override fun getRefreshKey(state: PagingState<Int, PostEntity>): Int? {
println("refresh")
return state.anchorPosition ?: 1
}
}
class FakePostDao(val posts: MutableList<PostEntity> = mutableListOf()) : PostDao {
val pagingSource = FakePostsPagingSource()
override suspend fun insertPosts(posts: List<PostEntity>) {
this.posts.addAll(posts)
println("insertPosts")
updatePagingSource()
}
override suspend fun updatePost(post: PostEntity) {
onValidPost(post.id) {
posts[it] = post
updatePagingSource()
}
}
private fun onValidPost(postId: Long, block: (index: Int) -> Unit): Boolean {
println("onValidPost")
val index = posts.indexOfFirst { it.id == postId }
if (index != -1) {
block(index)
return true
}
return false
}
override suspend fun updatePost(postId: Long, subscriberCount: Long) {
onValidPost(postId) {
posts[it] = posts[it].copy(subscriberCount = subscriberCount)
updatePagingSource()
}
}
override suspend fun getPostById(postId: Long): PostEntity? {
val index = posts.indexOfFirst { it.id == postId }
return if (index != -1) {
posts[index]
} else {
null
}
}
override suspend fun getPosts(): List<PostEntity> {
println("getPosts")
return posts
}
override fun getPostsPagingSource(): PagingSource<Int, PostEntity> {
println("getPostsPagingSource")
return pagingSource
}
override suspend fun clearAll() {
posts.clear()
updatePagingSource()
}
private fun updatePagingSource() {
println("updatePagingSource")
pagingSource.posts = posts
}
@Transaction
override suspend fun refreshPosts(newPosts: List<PostEntity>) {
println("refreshPosts")
clearAll()
insertPosts(newPosts)
}
}