Commit 1d8e0e60 authored by Exbrayat Cédric's avatar Exbrayat Cédric Committed by Jean-Baptiste Nizet
Browse files

feat: add optional aggregations to the search service

parent 990634a3
package fr.inra.urgi.rare.dao;
import fr.inra.urgi.rare.domain.GeneticResource;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.data.elasticsearch.core.aggregation.AggregatedPage;
/**
* Custom methods of the {@link GeneticResourceDao}
......@@ -12,7 +12,11 @@ public interface GeneticResourceDaoCustom {
/**
* Searches for the given text anywhere (except in identifier, URL and numeric fields) in the genetic resources,
* and returns the requested page (results are sorted by score, in descending order)
* and returns the requested page (results are sorted by score, in descending order).
* @param aggregate if true, terms aggregations are requested and present in the returned value. Otherwise,
* the returned aggregated page has no aggregation.
*/
Page<GeneticResource> search(String query, Pageable page);
AggregatedPage<GeneticResource> search(String query,
boolean aggregate,
Pageable page);
}
......@@ -8,10 +8,10 @@ import java.util.stream.Collectors;
import java.util.stream.Stream;
import fr.inra.urgi.rare.domain.GeneticResource;
import org.springframework.data.domain.Page;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.springframework.data.domain.Pageable;
import org.springframework.data.elasticsearch.core.ElasticsearchTemplate;
import org.springframework.data.elasticsearch.core.query.NativeSearchQuery;
import org.springframework.data.elasticsearch.core.aggregation.AggregatedPage;
import org.springframework.data.elasticsearch.core.query.NativeSearchQueryBuilder;
/**
......@@ -20,6 +20,8 @@ import org.springframework.data.elasticsearch.core.query.NativeSearchQueryBuilde
*/
public class GeneticResourceDaoImpl implements GeneticResourceDaoCustom {
private static final int MAX_BUCKETS = 100;
/**
* Contains the fields searchable on a {@link GeneticResource}.
* This is basically all fields at the exception of a few ones like `identifier`,
......@@ -48,11 +50,19 @@ public class GeneticResourceDaoImpl implements GeneticResourceDaoCustom {
}
@Override
public Page<GeneticResource> search(String query, Pageable page) {
NativeSearchQuery searchQuery = new NativeSearchQueryBuilder()
public AggregatedPage<GeneticResource> search(String query,
boolean aggregate,
Pageable page) {
NativeSearchQueryBuilder builder = new NativeSearchQueryBuilder()
.withQuery(multiMatchQuery(query, SEARCHABLE_FIELDS.toArray(new String[0])))
.withPageable(page)
.build();
return elasticsearchTemplate.queryForPage(searchQuery, GeneticResource.class);
.withPageable(page);
if (aggregate) {
Stream.of(RareAggregation.values()).forEach(rareAggregation ->
builder.addAggregation(AggregationBuilders.terms(rareAggregation.getName())
.field(rareAggregation.getField())
.size(MAX_BUCKETS)));
}
return elasticsearchTemplate.queryForPage(builder.build(), GeneticResource.class);
}
}
package fr.inra.urgi.rare.dao;
/**
* Enum listing the terms aggregations used by RARe, and their corresponding name and field
* @author JB Nizet
*/
public enum RareAggregation {
DOMAIN("domain", "domain.keyword"),
BIOTOPE("biotope", "biotopeType.keyword"),
MATERIAL("material", "materialType.keyword"),
COUNTRY_OF_ORIGIN("coo", "countryOfOrigin.keyword");
private final String name;
private final String field;
RareAggregation(String name, String field) {
this.name = name;
this.field = field;
}
public String getName() {
return name;
}
public String getField() {
return field;
}
}
package fr.inra.urgi.rare.dto;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;
import com.fasterxml.jackson.annotation.JsonUnwrapped;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.springframework.data.elasticsearch.core.aggregation.AggregatedPage;
/**
* DTO for a page containing additional aggregations
* @author JB Nizet
*/
public final class AggregatedPageDTO<T> {
@JsonUnwrapped
private final PageDTO<T> page;
private final List<AggregationDTO> aggregations;
public AggregatedPageDTO(PageDTO<T> page, List<AggregationDTO> aggregations) {
this.page = page;
this.aggregations = aggregations;
}
public static <T> AggregatedPageDTO<T> fromPage(AggregatedPage<T> page) {
return new AggregatedPageDTO<>(
PageDTO.fromPage(page),
toAggregationDTOs(page.getAggregations()));
}
public static <T, R> AggregatedPageDTO<R> fromPage(AggregatedPage<T> page, Function<T, R> mapper) {
return new AggregatedPageDTO<>(
PageDTO.fromPage(page, mapper),
toAggregationDTOs(page.getAggregations()));
}
private static List<AggregationDTO> toAggregationDTOs(Aggregations aggregations) {
if (aggregations == null) {
return Collections.emptyList();
}
return aggregations.asList()
.stream()
.filter(aggregation -> aggregation instanceof Terms)
.map(Terms.class::cast)
.map(AggregationDTO::new)
.collect(Collectors.toList());
}
public PageDTO<T> getPage() {
return page;
}
public List<AggregationDTO> getAggregations() {
return aggregations;
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
AggregatedPageDTO<?> that = (AggregatedPageDTO<?>) o;
return Objects.equals(page, that.page) &&
Objects.equals(aggregations, that.aggregations);
}
@Override
public int hashCode() {
return Objects.hash(page, aggregations);
}
@Override
public String toString() {
return "AggregatedPageDTO{" +
"page=" + page +
", aggregations=" + aggregations +
'}';
}
}
package fr.inra.urgi.rare.dto;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
/**
* A DTO for a terms aggregation, containing the name of the aggregation (which is used to determine the aggregated
* field), and the various buckets, i.e. the various values of this field.
* @author JB Nizet
*/
public final class AggregationDTO {
private final String name;
private final List<BucketDTO> buckets;
public AggregationDTO(Terms aggregation) {
this.name = aggregation.getName();
this.buckets = Collections.unmodifiableList(
aggregation.getBuckets().stream().map(BucketDTO::new).collect(Collectors.toList())
);
}
public String getName() {
return name;
}
public List<BucketDTO> getBuckets() {
return buckets;
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
AggregationDTO that = (AggregationDTO) o;
return Objects.equals(name, that.name) &&
Objects.equals(buckets, that.buckets);
}
@Override
public int hashCode() {
return Objects.hash(name, buckets);
}
@Override
public String toString() {
return "AggregationDTO{" +
"name='" + name + '\'' +
", buckets=" + buckets +
'}';
}
}
package fr.inra.urgi.rare.dto;
import java.util.Objects;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
/**
* A bucket, containing a field value and the number of documents falling into the bucket
* @author JB Nizet
*/
public final class BucketDTO {
private final String key;
private final long documentCount;
public BucketDTO(String key, long documentCount) {
this.key = key;
this.documentCount = documentCount;
}
public BucketDTO(Terms.Bucket bucket) {
this(bucket.getKeyAsString(), bucket.getDocCount());
}
public String getKey() {
return key;
}
public long getDocumentCount() {
return documentCount;
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
BucketDTO bucketDTO = (BucketDTO) o;
return documentCount == bucketDTO.documentCount &&
Objects.equals(key, bucketDTO.key);
}
@Override
public int hashCode() {
return Objects.hash(key, documentCount);
}
@Override
public String toString() {
return "BucketDTO{" +
"key='" + key + '\'' +
", documentCount=" + documentCount +
'}';
}
}
......@@ -47,10 +47,10 @@ public final class PageDTO<T> {
public static <T> PageDTO<T> fromPage(Page<T> page) {
return new PageDTO<>(page.getContent(),
page.getNumber(),
page.getSize(),
page.getTotalElements(),
page.getTotalPages());
page.getNumber(),
page.getSize(),
page.getTotalElements(),
page.getTotalPages());
}
public static <T, R> PageDTO<R> fromPage(Page<T> page, Function<T, R> mapper) {
......
......@@ -4,7 +4,7 @@ import java.util.Optional;
import fr.inra.urgi.rare.dao.GeneticResourceDao;
import fr.inra.urgi.rare.domain.GeneticResource;
import fr.inra.urgi.rare.dto.PageDTO;
import fr.inra.urgi.rare.dto.AggregatedPageDTO;
import org.springframework.data.domain.PageRequest;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
......@@ -27,7 +27,13 @@ public class SearchController {
}
@GetMapping
public PageDTO<GeneticResource> search(@RequestParam("query") String query, Optional<Integer> page) {
return PageDTO.fromPage(geneticResourceDao.search(query, PageRequest.of(page.orElse(0), PAGE_SIZE)));
public AggregatedPageDTO<GeneticResource> search(@RequestParam("query") String query,
@RequestParam("agg") Optional<Boolean> agg,
@RequestParam("page") Optional<Integer> page) {
boolean aggregate = agg.orElse(false);
return AggregatedPageDTO.fromPage(geneticResourceDao.search(query,
aggregate,
PageRequest.of(page.orElse(0), PAGE_SIZE)));
}
}
......@@ -2,12 +2,15 @@ package fr.inra.urgi.rare.dao;
import static org.assertj.core.api.Assertions.assertThat;
import java.util.Arrays;
import java.util.Collections;
import java.util.function.BiConsumer;
import fr.inra.urgi.rare.config.ElasticSearchConfig;
import fr.inra.urgi.rare.domain.GeneticResource;
import fr.inra.urgi.rare.domain.GeneticResourceBuilder;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.aggregations.bucket.terms.Terms.Bucket;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
......@@ -16,6 +19,7 @@ import org.springframework.boot.test.autoconfigure.json.JsonTest;
import org.springframework.context.annotation.Import;
import org.springframework.data.domain.PageRequest;
import org.springframework.data.domain.Pageable;
import org.springframework.data.elasticsearch.core.aggregation.AggregatedPage;
import org.springframework.test.context.TestPropertySource;
import org.springframework.test.context.junit.jupiter.SpringExtension;
......@@ -139,7 +143,7 @@ class GeneticResourceDaoTest {
GeneticResource geneticResource = new GeneticResourceBuilder().build();
geneticResourceDao.save(geneticResource);
assertThat(geneticResourceDao.search(geneticResource.getId(), firstPage).getContent()).isEmpty();
assertThat(geneticResourceDao.search(geneticResource.getId(), false, firstPage).getContent()).isEmpty();
}
@Test
......@@ -148,7 +152,7 @@ class GeneticResourceDaoTest {
new GeneticResourceBuilder().withDataURL("foo bar baz").withPortalURL("foo bar baz").build();
geneticResourceDao.save(geneticResource);
assertThat(geneticResourceDao.search("bar", firstPage).getContent()).isEmpty();
assertThat(geneticResourceDao.search("bar", false, firstPage).getContent()).isEmpty();
}
private void shouldSearch(BiConsumer<GeneticResourceBuilder, String> config) {
......@@ -158,8 +162,58 @@ class GeneticResourceDaoTest {
geneticResourceDao.save(geneticResource);
assertThat(geneticResourceDao.search("bar", firstPage).getContent()).hasSize(1);
assertThat(geneticResourceDao.search("bing", firstPage).getContent()).isEmpty();
AggregatedPage<GeneticResource> result = geneticResourceDao.search("bar", false, firstPage);
assertThat(result.getContent()).hasSize(1);
assertThat(result.getAggregations()).isNull();
result = geneticResourceDao.search("bing", false, firstPage);
assertThat(result.getContent()).isEmpty();
}
@Test
public void shouldSearchAndAggregate() {
GeneticResource geneticResource1 = new GeneticResourceBuilder()
.withId("r1")
.withName("foo")
.withDomain("Plantae")
.withBiotopeType(Arrays.asList("Biotope", "Human host"))
.withMaterialType(Arrays.asList("Specimen", "DNA"))
.withCountryOfOrigin("France")
.build();
GeneticResource geneticResource2 = new GeneticResourceBuilder()
.withId("r2")
.withName("bar foo")
.withDomain("Fungi")
.withBiotopeType(Arrays.asList("Biotope"))
.withMaterialType(Arrays.asList("DNA"))
.withCountryOfOrigin("France")
.build();
geneticResourceDao.saveAll(Arrays.asList(geneticResource1, geneticResource2));
AggregatedPage<GeneticResource> result = geneticResourceDao.search("foo", true, firstPage);
assertThat(result.getContent()).hasSize(2);
Terms domain = result.getAggregations().get(RareAggregation.DOMAIN.getName());
assertThat(domain.getName()).isEqualTo(RareAggregation.DOMAIN.getName());
assertThat(domain.getBuckets()).extracting(Bucket::getKeyAsString).containsOnly("Plantae", "Fungi");
assertThat(domain.getBuckets()).extracting(Bucket::getDocCount).containsOnly(1L);
Terms biotopeType = result.getAggregations().get(RareAggregation.BIOTOPE.getName());
assertThat(biotopeType.getName()).isEqualTo(RareAggregation.BIOTOPE.getName());
assertThat(biotopeType.getBuckets()).extracting(Bucket::getKeyAsString).containsExactly("Biotope", "Human host");
assertThat(biotopeType.getBuckets()).extracting(Bucket::getDocCount).containsExactly(2L, 1L);
Terms materialType = result.getAggregations().get(RareAggregation.MATERIAL.getName());
assertThat(materialType.getName()).isEqualTo(RareAggregation.MATERIAL.getName());
assertThat(materialType.getBuckets()).extracting(Bucket::getKeyAsString).containsExactly("DNA", "Specimen");
assertThat(materialType.getBuckets()).extracting(Bucket::getDocCount).containsExactly(2L, 1L);
Terms countryOfOrigin = result.getAggregations().get(RareAggregation.COUNTRY_OF_ORIGIN.getName());
assertThat(countryOfOrigin.getName()).isEqualTo(RareAggregation.COUNTRY_OF_ORIGIN.getName());
assertThat(countryOfOrigin.getBuckets()).extracting(Bucket::getKeyAsString).containsExactly("France");
assertThat(countryOfOrigin.getBuckets()).extracting(Bucket::getDocCount).containsExactly(2L);
}
}
package fr.inra.urgi.rare.search;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.bucket.terms.Terms.Bucket;
/**
* A mock implementation of {@link Bucket}
* @author JB Nizet
*/
public final class MockBucket implements Bucket {
private final String key;
private final long docCount;
public MockBucket(String key, long docCount) {
this.key = key;
this.docCount = docCount;
}
@Override
public String getKeyAsString() {
return this.key;
}
@Override
public Object getKey() {
return this.key;
}
@Override
public long getDocCount() {
return this.docCount;
}
@Override
public Number getKeyAsNumber() {
throw new UnsupportedOperationException();
}
@Override
public long getDocCountError() {
throw new UnsupportedOperationException();
}
@Override
public Aggregations getAggregations() {
throw new UnsupportedOperationException();
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
throw new UnsupportedEncodingException();
}
}
package fr.inra.urgi.rare.search;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
/**
* A mock implementation of a Terms aggregation
* @author JB Nizet
*/
public final class MockTermsAggregation implements Terms {
private final String name;
private final List<Bucket> buckets;
public MockTermsAggregation(String name, List<Bucket> buckets) {
this.name = name;
this.buckets = buckets;
}
@Override
public String getName() {
return this.name;
}
@Override
public List<? extends Bucket> getBuckets() {
return buckets;
}
@Override
public Bucket getBucketByKey(String term) {
throw new UnsupportedOperationException();
}
@Override
public long getDocCountError() {
throw new UnsupportedOperationException();
}
@Override
public long getSumOfOtherDocCounts() {
throw new UnsupportedOperationException();
}
@Override
public String getType() {
throw new UnsupportedOperationException();
}
@Override
public Map<String, Object> getMetaData() {
throw new UnsupportedOperationException();
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
throw new UnsupportedOperationException();
}
}
......@@ -6,19 +6,21 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
import java.util.Arrays;
import java.util.Collections;
import fr.inra.urgi.rare.config.SecurityConfig;
import fr.inra.urgi.rare.dao.GeneticResourceDao;
import fr.inra.urgi.rare.domain.GeneticResource;
import fr.inra.urgi.rare.domain.GeneticResourceBuilder;