#include<bits/stdc++.h>
using namespace std;
#define ll long long
ll c[5022],dp[10001][3],r[10001],home[10001];
ll n,m,cnt;
struct node{
	ll next,to;
}e[20001];
void add(ll a,ll b){
	e[++cnt].to=n;
	e[cnt].next=home[a];
	home[a]=cnt;
}
void dfs(ll p,ll fa){
	if(p<=n){
		dp[p][0]=0x3f3f3f3f3f3f;
		dp[p][1]=c[p]==0?1:0x3f3f3f3f3f3f;
		dp[p][2]=c[p]==1?1:0x3f3f3f3f3f3f;
		return ;
	}
	for(ll i=home[p];i;i=e[i].next){
		ll t=e[i].to;
		if(t==fa) continue;
		dfs(t,p);
		dp[p][0]+=min(dp[t][1],dp[t][2]);
		dp[p][1]+=dp[t][1]-1;
		dp[p][2]+=dp[t][2]-1;
	}
	dp[p][1]++;
	dp[p][2]++;
}
int main(){
	scanf("%lld %lld",&m,&n);
	queue<ll> q;
	for(ll i=1;i<=n;i++) scanf("%lld",&c[i]);
	for(ll i=1;i<m;i++){
		ll a,b;
		scanf("%lld %lld",&a,&b);
		add(a,b);
		add(b,a);
		r[a]++;
		r[b]++;
		if(r[a]==2) q.push(a);
		if(r[b]==2) q.push(b);
	}
	ll ans=0x3f3f3f3f3f3f3f3f;
	while(!q.empty()){
		ll root=q.front();
		q.pop();
		for(ll i=1;i<=m;i++){
			dp[i][1]=dp[i][0]=dp[i][2]=0;
		}
		dfs(root,0);
		ans=min(ans,dp[root][0]);
	}
	printf("%lld",ans);
	return 0;
}