package com.kidgrow.zuul.auth;
import cn.hutool.core.collection.CollectionUtil;
import com.kidgrow.common.constant.SecurityConstants;
import com.kidgrow.common.model.SysOrganization;
import com.kidgrow.common.model.SysUser;
import lombok.SneakyThrows;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.security.web.server.WebFilterExchange;
import org.springframework.security.web.server.authentication.ServerAuthenticationSuccessHandler;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;
import java.net.URLEncoder;
import java.util.List;
/**
* 石家庄喜高科技有限责任公司 版权所有 © Copyright 2020
*
* @Description: 认证成功处理类
* @Project:
* @CreateDate: Created in 2020/2/22 22:44
* @Author: liuke
*/
public class Oauth2AuthSuccessHandler implements ServerAuthenticationSuccessHandler {

    /** Charset name used when URL-encoding organization names before placing them in headers. */
    private static final String HEADER_VALUE_ENCODING = "UTF-8";

    /**
     * Copies details of the successful authentication into request headers and continues the
     * filter chain with the mutated request.
     *
     * <p>Headers written: tenant, client id and a comma-joined list of the granted authorities.
     * When the principal is a {@link SysUser}, the user id and username are written as well, and,
     * if the user carries exactly two organization entries, the id and URL-encoded name of entry 0
     * (organization) and entry 1 (department).
     *
     * <p>The {@code authentication} is cast to {@link OAuth2Authentication}; any other
     * implementation fails with a {@link ClassCastException}.
     *
     * @param webFilterExchange the exchange and filter chain of the current request
     * @param authentication    the authentication produced by the successful login
     * @return the result of continuing the filter chain with the header-enriched request
     */
    @SneakyThrows
    @Override
    public Mono<Void> onAuthenticationSuccess(WebFilterExchange webFilterExchange, Authentication authentication) {
        MultiValueMap<String, String> headerValues = new LinkedMultiValueMap<>(4);
        Object principal = authentication.getPrincipal();
        String organizationTenantId = "";

        // In client-credentials mode the principal is not a SysUser (only a client id is
        // available), so the user and organization headers are skipped.
        if (principal instanceof SysUser) {
            SysUser user = (SysUser) principal;
            headerValues.add(SecurityConstants.USER_ID_HEADER, String.valueOf(user.getId()));
            headerValues.add(SecurityConstants.USER_HEADER, user.getUsername());

            // NOTE(review): assumes getOrganizations() yields SysOrganization elements — confirm against SysUser.
            @SuppressWarnings("unchecked")
            List<SysOrganization> organizations = (List<SysOrganization>) user.getOrganizations();

            // Organization headers are only emitted when exactly two entries are present:
            // entry 0 feeds the organization headers, entry 1 the department headers.
            if (organizations != null && organizations.size() == 2) {
                SysOrganization organization = organizations.get(0);
                SysOrganization department = organizations.get(1);
                headerValues.add(SecurityConstants.USER_ORG_ID_HEADER, String.valueOf(organization.getId()));
                headerValues.add(SecurityConstants.USER_ORG_NAME_HEADER,
                        URLEncoder.encode(organization.getOrgName(), HEADER_VALUE_ENCODING));
                headerValues.add(SecurityConstants.USER_DEP_ID_HEADER, String.valueOf(department.getId()));
                headerValues.add(SecurityConstants.USER_DEP_NAME_HEADER,
                        URLEncoder.encode(department.getOrgName(), HEADER_VALUE_ENCODING));
                organizationTenantId = String.valueOf(organization.getId());
            }
        }

        OAuth2Authentication oauth2Authentication = (OAuth2Authentication) authentication;
        String clientId = oauth2Authentication.getOAuth2Request().getClientId();

        headerValues.add(SecurityConstants.TENANT_HEADER, resolveTenantId(clientId, organizationTenantId));
        headerValues.add(SecurityConstants.CLIENT_HEADER, clientId);
        headerValues.add(SecurityConstants.ROLE_HEADER, CollectionUtil.join(authentication.getAuthorities(), ","));

        ServerWebExchange exchange = webFilterExchange.getExchange();
        // Replace rather than append: with addAll, a same-named header sent by the caller would
        // stay ahead of the value computed here and could be read first downstream.
        // NOTE(review): identity headers that are not computed for this request (e.g. user headers
        // in client-credentials mode) are still passed through unchanged — consider stripping them.
        ServerHttpRequest enrichedRequest = exchange.getRequest().mutate()
                .headers(httpHeaders -> httpHeaders.putAll(headerValues))
                .build();
        ServerWebExchange enrichedExchange = exchange.mutate().request(enrichedRequest).build();
        return webFilterExchange.getChain().filter(enrichedExchange);
    }

    /**
     * Picks the tenant id to forward for the given OAuth2 client. The mapping is business
     * specific and is expected to be adjusted per deployment.
     *
     * @param clientId             the OAuth2 client id of the authenticated request
     * @param organizationTenantId the id of the user's first organization, or an empty string
     *                             when no organization data was available
     * @return {@code organizationTenantId} for the {@code hospital} client, otherwise the client id
     */
    private static String resolveTenantId(String clientId, String organizationTenantId) {
        switch (clientId) {
            case "hospital":
                return organizationTenantId;
            case "webApp":
                return "webApp";
            default:
                return clientId;
        }
    }
}