/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.action;

import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Executor;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.OriginSettingClient;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.injection.guice.Inject;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.inference.ModelAliasMetadata;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.action.TransportStartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;

public class TransportGetTrainedModelsAction
extends HandledTransportAction<GetTrainedModelsAction.Request, GetTrainedModelsAction.Response> {
    private final TrainedModelProvider provider;
    private final ClusterService clusterService;
    private final Client client;

    @Inject
    public TransportGetTrainedModelsAction(TransportService transportService, ActionFilters actionFilters, ClusterService clusterService, Client client, TrainedModelProvider trainedModelProvider) {
        super("cluster:monitor/xpack/ml/inference/get", transportService, actionFilters, GetTrainedModelsAction.Request::new, (Executor)EsExecutors.DIRECT_EXECUTOR_SERVICE);
        this.provider = trainedModelProvider;
        this.clusterService = clusterService;
        this.client = client;
    }

    protected void doExecute(Task task, GetTrainedModelsAction.Request request, ActionListener<GetTrainedModelsAction.Response> listener) {
        TaskId parentTaskId = new TaskId(this.clusterService.localNode().getId(), task.getId());
        this.provider.expandIds(request.getResourceId(), request.isAllowNoResources(), request.getPageParams(), new HashSet<String>(request.getTags()), ModelAliasMetadata.fromState((ClusterState)this.clusterService.state()), parentTaskId, Collections.emptySet(), (ActionListener<Tuple<Long, Map<String, Set<String>>>>)listener.delegateFailureAndWrap((delegate, totalAndIds) -> {
            GetTrainedModelsAction.Response.Builder responseBuilder = GetTrainedModelsAction.Response.builder();
            responseBuilder.setTotalCount(((Long)totalAndIds.v1()).longValue());
            if (((Map)totalAndIds.v2()).isEmpty()) {
                delegate.onResponse((Object)responseBuilder.build());
                return;
            }
            if (request.getIncludes().isIncludeModelDefinition() && ((Map)totalAndIds.v2()).size() > 1) {
                delegate.onFailure((Exception)((Object)ExceptionsHelper.badRequestException((String)"Getting model definition is not supported when getting more than one model", (Object[])new Object[0])));
                return;
            }
            if (request.getIncludes().isIncludeDefinitionStatus() && ((Map)totalAndIds.v2()).size() > 1) {
                delegate.onFailure((Exception)((Object)ExceptionsHelper.badRequestException((String)"Getting the model download status is not supported when getting more than one model", (Object[])new Object[0])));
                return;
            }
            ActionListener getModelDefinitionStatusListener = delegate.delegateFailureAndWrap((delegate2, configs) -> {
                if (!request.getIncludes().isIncludeDefinitionStatus()) {
                    delegate2.onResponse((Object)responseBuilder.setModels(configs).build());
                    return;
                }
                assert (configs.size() <= 1);
                if (configs.isEmpty()) {
                    delegate2.onResponse((Object)responseBuilder.setModels(configs).build());
                    return;
                }
                if (((TrainedModelConfig)configs.get(0)).getModelType() != TrainedModelType.PYTORCH) {
                    delegate2.onFailure((Exception)((Object)ExceptionsHelper.badRequestException((String)"Definition status is only relevant to PyTorch model types", (Object[])new Object[0])));
                    return;
                }
                TransportStartTrainedModelDeploymentAction.checkFullModelDefinitionIsPresent(new OriginSettingClient(this.client, "ml"), (TrainedModelConfig)configs.get(0), false, null, (ActionListener<Tuple<String, Long>>)delegate2.delegateFailureAndWrap((l, modelIdAndLength) -> {
                    ((TrainedModelConfig)configs.get(0)).setFullDefinition((Long)modelIdAndLength.v2() > 0L);
                    l.onResponse((Object)responseBuilder.setModels(configs).build());
                }));
            });
            if (request.getIncludes().isIncludeModelDefinition()) {
                Map.Entry modelIdAndAliases = ((Map)totalAndIds.v2()).entrySet().iterator().next();
                this.provider.getTrainedModel((String)modelIdAndAliases.getKey(), (Set)modelIdAndAliases.getValue(), request.getIncludes(), parentTaskId, (ActionListener<TrainedModelConfig>)getModelDefinitionStatusListener.delegateFailureAndWrap((l, config) -> l.onResponse(Collections.singletonList(config))));
            } else {
                this.provider.getTrainedModels((Map)totalAndIds.v2(), request.getIncludes(), request.isAllowNoResources(), parentTaskId, (ActionListener<List<TrainedModelConfig>>)getModelDefinitionStatusListener);
            }
        }));
    }
}

